keras: Avoid unneccesary call to .call() when building models with subclassing.
This fixes a regression in the defun microbenchmarks (ResNet50Benchmarks.eager_train_with_defun_gpu_batch_32_channels_first etc.) in tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py seen after https://github.com/tensorflow/tensorflow/commit/9a84277be2cb8233c5c14270db6fcdff31ab4d93 (which embeds a model in model) Without this change, converting a model call to a graph function using something like: model.call = tfe.defun(model.call) could result in redundant nodes being added to the graph function as the model._set_inputs() call would invoke model.call() again. PiperOrigin-RevId: 187391494
Loading
Please sign in to comment