From f35dc0a522ae630902baa5be16d2a53b59266770 Mon Sep 17 00:00:00 2001 From: Bruno Goncalves <882745+brunomorishita@users.noreply.github.com> Date: Sat, 28 Apr 2018 19:24:22 -0300 Subject: [PATCH 001/540] Fix cmake library path for libpng16.a --- tensorflow/contrib/cmake/external/png.cmake | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index ad2af01bc0..1a147e9c8e 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== include (ExternalProject) +include (GNUInstallDirs) set(png_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/png_archive) set(png_URL https://mirror.bazel.build/github.com/glennrp/libpng/archive/v1.6.34.tar.gz) @@ -35,7 +36,7 @@ if(WIN32) endif() endif() else() - set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng16.a) + set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/${CMAKE_INSTALL_LIBDIR}/libpng16.a) endif() set(png_HEADERS -- GitLab From e298fae53bee33eaed6ab152d029db5c6fac34c3 Mon Sep 17 00:00:00 2001 From: JxKing Date: Thu, 31 May 2018 12:55:35 +0800 Subject: [PATCH 002/540] fix multiple values for keyword argument error --- .../contrib/opt/python/training/model_average_optimizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index b6b10e500b..e4d1ae5d63 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -89,7 +89,9 @@ class ModelAverageCustomGetter(object): self._local_2_global[local_var] = global_variable return local_var else: - return getter(name, trainable, collections, *args, **kwargs) + kwargs['trainable'] = trainable + kwargs['collections'] = collections + return getter(name, *args, **kwargs) class ModelAverageOptimizer(optimizer.Optimizer): -- GitLab From 7004927328cd8166c6858984ec649e4eea0ceab0 Mon Sep 17 00:00:00 2001 From: JxKing Date: Thu, 31 May 2018 12:57:52 +0800 Subject: [PATCH 003/540] fix multiple values for keyword argument for easgd --- .../contrib/opt/python/training/elastic_average_optimizer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 5763593b81..545c3477bf 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -100,7 +100,9 @@ class ElasticAverageCustomGetter(object): self._global_map[local_var] = global_center_variable return local_var else: - return getter(name, trainable, collections, *args, **kwargs) + kwargs['trainable'] = trainable + kwargs['collections'] = collections + return getter(name, *args, **kwargs) class ElasticAverageOptimizer(optimizer.Optimizer): -- GitLab From bdc37544a98cd777e71f83fd1c46a42038004476 Mon Sep 17 00:00:00 2001 From: JxKing Date: Thu, 31 May 2018 12:59:45 +0800 Subject: [PATCH 004/540] place easgd in ea_coustom_getter scope --- .../elastic_average_optimizer_test.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py index 5ed8057b86..9d57dc08f6 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py @@ -79,21 +79,21 @@ def _get_workers(num_workers, period, workers, moving_rate): var_0 = variable_scope.get_variable(initializer=0.0, name="v0") var_1 = variable_scope.get_variable(initializer=1.0, name="v1") - with ops.device("/job:worker/task:" + str(worker_id)): - grads_0 = constant_op.constant(-1.0) - grads_1 = constant_op.constant(-1.0) - - sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) - opt = ElasticAverageOptimizer( - opt=sgd_opt, - num_worker=num_workers, - moving_rate=moving_rate, - communication_period=period, - ea_custom_getter=ea_coustom) - train_op = [ - opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]), - global_step) - ] + with ops.device("/job:worker/task:" + str(worker_id)): + grads_0 = constant_op.constant(-1.0) + grads_1 = constant_op.constant(-1.0) + + sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) + opt = ElasticAverageOptimizer( + opt=sgd_opt, + num_worker=num_workers, + moving_rate=moving_rate, + communication_period=period, + ea_custom_getter=ea_coustom) + train_op = [ + opt.apply_gradients(([grads_0, var_0], [grads_1, var_1]), + global_step) + ] easgd_hook = opt.make_session_run_hook(is_chief, worker_id) # Creates MonitoredSession sess = training.MonitoredTrainingSession( -- GitLab From f4020cfc79582aa689f7a575445b95e60974071f Mon Sep 17 00:00:00 2001 From: JxKing Date: Thu, 31 May 2018 13:01:25 +0800 Subject: [PATCH 005/540] place ma_opt in ma_coustom_getter scope --- .../training/model_average_optimizer_test.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py index 3acd940268..b1fc50a21f 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py @@ -80,28 +80,28 @@ def _get_workers(num_workers, steps, workers): var_0 = variable_scope.get_variable(initializer=0.0, name="v0") var_1 = variable_scope.get_variable(initializer=1.0, name="v1") - with ops.device("/job:worker/task:" + str(worker_id)): - if worker_id == 0: - grads_0 = constant_op.constant(-1.0) - grads_1 = constant_op.constant(-1.0) - else: - grads_0 = constant_op.constant(-2.0) - grads_1 = constant_op.constant(-2.0) - sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) - opt = model_average_optimizer.ModelAverageOptimizer( - opt=sgd_opt, - num_worker=num_workers, - ma_custom_getter=ma_coustom, - is_chief=is_chief, - interval_steps=steps) - train_op = [ - opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]], - global_step) - ] - easgd_hook = opt.make_session_run_hook() + with ops.device("/job:worker/task:" + str(worker_id)): + if worker_id == 0: + grads_0 = constant_op.constant(-1.0) + grads_1 = constant_op.constant(-1.0) + else: + grads_0 = constant_op.constant(-2.0) + grads_1 = constant_op.constant(-2.0) + sgd_opt = gradient_descent.GradientDescentOptimizer(1.0) + opt = model_average_optimizer.ModelAverageOptimizer( + opt=sgd_opt, + num_worker=num_workers, + ma_custom_getter=ma_coustom, + is_chief=is_chief, + interval_steps=steps) + train_op = [ + opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]], + global_step) + ] + ma_hook = opt.make_session_run_hook() # Creates MonitoredSession sess = training.MonitoredTrainingSession( - workers[worker_id].target, hooks=[easgd_hook]) + workers[worker_id].target, hooks=[ma_hook]) sessions.append(sess) graphs.append(graph) -- GitLab From 6c279ad4055a2d568977a02a2eb3b1303117ac15 Mon Sep 17 00:00:00 2001 From: JxKing Date: Thu, 31 May 2018 19:23:32 +0800 Subject: [PATCH 006/540] fix "workers share local variables" error --- .../contrib/opt/python/training/model_average_optimizer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py index e4d1ae5d63..746df77ba2 100644 --- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py @@ -91,7 +91,11 @@ class ModelAverageCustomGetter(object): else: kwargs['trainable'] = trainable kwargs['collections'] = collections - return getter(name, *args, **kwargs) + if ops.GraphKeys.LOCAL_VARIABLES in collections: + with ops.device(self._worker_device): + return getter(name, *args, **kwargs) + else: + return getter(name, *args, **kwargs) class ModelAverageOptimizer(optimizer.Optimizer): -- GitLab From 16c42f0d4826b12a5359281997ee3f8e27fd5a87 Mon Sep 17 00:00:00 2001 From: JxKing Date: Thu, 31 May 2018 19:24:19 +0800 Subject: [PATCH 007/540] fix "workers share local variables" error --- .../opt/python/training/elastic_average_optimizer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py index 545c3477bf..209c4611f3 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py @@ -102,7 +102,12 @@ class ElasticAverageCustomGetter(object): else: kwargs['trainable'] = trainable kwargs['collections'] = collections - return getter(name, *args, **kwargs) + if ops.GraphKeys.LOCAL_VARIABLES in collections: + with ops.device(self._worker_device): + return getter(name, *args, **kwargs) + else: + return getter(name, *args, **kwargs) + class ElasticAverageOptimizer(optimizer.Optimizer): -- GitLab From f78fd433118830482dddbf6055751898a19265de Mon Sep 17 00:00:00 2001 From: jiefangxuanyan <505745416@qq.com> Date: Wed, 13 Jun 2018 17:28:23 +0800 Subject: [PATCH 008/540] Specify endianness in expected_result array to fix #15767. --- tensorflow/python/kernel_tests/decode_raw_op_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/kernel_tests/decode_raw_op_test.py b/tensorflow/python/kernel_tests/decode_raw_op_test.py index 122a9ed469..0bd8bc3c7b 100644 --- a/tensorflow/python/kernel_tests/decode_raw_op_test.py +++ b/tensorflow/python/kernel_tests/decode_raw_op_test.py @@ -79,7 +79,7 @@ class DecodeRawOpTest(test.TestCase): decode = parsing_ops.decode_raw(in_bytes, out_type=dtypes.float16) self.assertEqual([None, None], decode.get_shape().as_list()) - expected_result = np.matrix([[1, -2, -3, 4]], dtype=np.float16) + expected_result = np.matrix([[1, -2, -3, 4]], dtype=" Date: Wed, 8 Aug 2018 14:34:16 -0700 Subject: [PATCH 009/540] Add deprecation warning to tf.gfile.FastGFile. Fixes #12663. --- tensorflow/python/platform/gfile.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index 45de047894..510701e344 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -33,6 +33,7 @@ from tensorflow.python.lib.io.file_io import rename as Rename from tensorflow.python.lib.io.file_io import stat as Stat from tensorflow.python.lib.io.file_io import walk as Walk # pylint: enable=unused-import +from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export @@ -52,6 +53,7 @@ class GFile(_FileIO): @tf_export('gfile.FastGFile') +@deprecated(None, 'Use tf.gfile.GFile.') class FastGFile(_FileIO): """File I/O wrappers without thread locking. -- GitLab From 6c14d85b41c565ed9dabc3677aedf76757097242 Mon Sep 17 00:00:00 2001 From: rasmi Date: Wed, 8 Aug 2018 16:35:12 -0700 Subject: [PATCH 010/540] Changed order of export and deprecated decorators. --- tensorflow/python/platform/gfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index 510701e344..ac53609434 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -52,8 +52,8 @@ class GFile(_FileIO): super(GFile, self).__init__(name=name, mode=mode) -@tf_export('gfile.FastGFile') @deprecated(None, 'Use tf.gfile.GFile.') +@tf_export('gfile.FastGFile') class FastGFile(_FileIO): """File I/O wrappers without thread locking. -- GitLab From c3c6c45987692e8bc73eff2f10f9ec1a82f55287 Mon Sep 17 00:00:00 2001 From: rasmi Date: Thu, 9 Aug 2018 10:27:37 -0700 Subject: [PATCH 011/540] Moved @deprecated decorator to __init__ --- tensorflow/python/platform/gfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py index ac53609434..5927bc2409 100644 --- a/tensorflow/python/platform/gfile.py +++ b/tensorflow/python/platform/gfile.py @@ -52,7 +52,6 @@ class GFile(_FileIO): super(GFile, self).__init__(name=name, mode=mode) -@deprecated(None, 'Use tf.gfile.GFile.') @tf_export('gfile.FastGFile') class FastGFile(_FileIO): """File I/O wrappers without thread locking. @@ -64,6 +63,7 @@ class FastGFile(_FileIO): invocations in network filesystems). """ + @deprecated(None, 'Use tf.gfile.GFile.') def __init__(self, name, mode='r'): super(FastGFile, self).__init__(name=name, mode=mode) -- GitLab From 22ebbbc60e5d94d67cdf6c26b44919f7dbb8f600 Mon Sep 17 00:00:00 2001 From: feiquan Date: Mon, 13 Aug 2018 23:44:38 +0800 Subject: [PATCH 012/540] extends the tensor index operator to support character access --- tensorflow/contrib/autograph/operators/slices.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py index 04fbeb2f6e..d878bddf3c 100644 --- a/tensorflow/contrib/autograph/operators/slices.py +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import gen_string_ops # TODO(mdan): Support extended slices. @@ -57,6 +58,8 @@ def get_item(target, i, opts): elif tensor_util.is_tensor(target): if target.dtype == dtypes.variant: return _tf_tensor_list_get_item(target, i, opts) + if target.dtype == dtypes.string: + return _tf_tensor_string_get_item(target, i) else: return _tf_tensor_get_item(target, i) else: @@ -81,6 +84,10 @@ def _tf_tensor_get_item(target, i): """Overload of get_item that stages a Tensor (not Tensor list) read.""" return target[i] +def _tf_tensor_string_get_item(target, i): + """Overload of get_item that stages a Tensor string read.""" + x = gen_string_ops.substr(target, i, 1) + return x def _py_get_item(target, i): """Overload of get_item that executes a Python list modification.""" -- GitLab From 349d81c80a5b64ae09a36624571ec24d9e7a8b1d Mon Sep 17 00:00:00 2001 From: feiquan Date: Tue, 14 Aug 2018 00:07:28 +0800 Subject: [PATCH 013/540] add test for gen_item_tensor_string --- tensorflow/contrib/autograph/operators/slices_test.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index d4aacb9d20..9c0b2c77a1 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -46,6 +46,13 @@ class SlicesTest(test.TestCase): with self.test_session() as sess: self.assertAllEqual(sess.run(t), [3, 4]) + def test_get_item_tensor_string(self): + initial_str = constant_op.constant("abcd") + t = slices.get_item(initial_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) + + with self.test_session() as sess: + self.assertEqual(sess.run(t), b"b") + if __name__ == '__main__': test.main() -- GitLab From 48aef32dcd356fa6bae490fa1c853b9b2cdd4846 Mon Sep 17 00:00:00 2001 From: kouml Date: Wed, 15 Aug 2018 02:27:32 +0900 Subject: [PATCH 014/540] removing redundant semicolon --- tensorflow/contrib/lite/toco/python/toco_from_protos_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py index 3761e0095e..75c1c8970c 100644 --- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py +++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py @@ -50,7 +50,7 @@ class TocoFromProtosTest(googletest.TestCase): toco_flags.output_format = toco_flags_pb2.TFLITE toco_flags.inference_input_type = types_pb2.FLOAT toco_flags.inference_type = types_pb2.FLOAT - toco_flags.allow_custom_ops = True; + toco_flags.allow_custom_ops = True model_flags = model_flags_pb2.ModelFlags() input_array = model_flags.input_arrays.add() input_array.name = TensorName(in_tensor) -- GitLab From f2134cbd2ec4dd98f9f20ac41e4f46cdd0246af2 Mon Sep 17 00:00:00 2001 From: feiquan Date: Wed, 15 Aug 2018 08:47:22 +0800 Subject: [PATCH 015/540] use get_item_tensor_string for string with rank 0 --- tensorflow/contrib/autograph/operators/slices_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index 9c0b2c77a1..5300428462 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -53,6 +53,12 @@ class SlicesTest(test.TestCase): with self.test_session() as sess: self.assertEqual(sess.run(t), b"b") + initial_list_str = constant_op.constant(["abcd", "bcde"]) + t = slices.get_item(initial_list_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) + + with self.test_session() as sess: + self.assertEqual(sess.run(t), b"bcde") + if __name__ == '__main__': test.main() -- GitLab From 1843dc2bef2beabc1ac6765c14e03b1a07823bef Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 23 Jul 2018 14:43:28 -0700 Subject: [PATCH 016/540] Network.to_json should handle numpy.ndarray correctly --- tensorflow/python/keras/engine/network.py | 5 ++++- .../python/keras/engine/topology_test.py | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 708fa1c807..3cdd714d7e 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1574,7 +1574,10 @@ class Network(base_layer.Layer): def get_json_type(obj): # If obj is any numpy type if type(obj).__module__ == np.__name__: - return obj.item() + if isinstance(obj, np.ndarray): + return obj.tolist() + else: + return obj.item() # If obj is a python 'type' if type(obj).__name__ == type.__name__: diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 079c8dae71..3dfa933913 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -913,6 +913,28 @@ class TopologyConstructionTest(test.TestCase): self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4) + def test_constant_initializer_with_numpy(self): + + with self.test_session(): + model = keras.models.Sequential() + model.add( + keras.layers.Dense( + 2, + input_shape = (3,), + kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) + ) + ) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + json_str = model.to_json() + keras.models.model_from_json(json_str) + + if yaml is not None: + yaml_str = model.to_yaml() + keras.models.model_from_yaml(yaml_str) + + class DeferredModeTest(test.TestCase): def testDeferredTensorAttributes(self): -- GitLab From 5ef4de5b01d10c4dae86a1e69cf1296671d55e47 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 15 Aug 2018 17:40:22 -0700 Subject: [PATCH 017/540] Fix bad indentation --- tensorflow/python/keras/engine/topology_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 3dfa933913..25ae3a61c3 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -918,11 +918,11 @@ class TopologyConstructionTest(test.TestCase): with self.test_session(): model = keras.models.Sequential() model.add( - keras.layers.Dense( - 2, - input_shape = (3,), - kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) - ) + keras.layers.Dense( + 2, + input_shape = (3,), + kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) + ) ) model.add(keras.layers.Dense(3)) model.compile(loss='mse', optimizer='sgd', metrics=['acc']) -- GitLab From 4a1fdff581db18e3262daebbc1f9543936bf47d1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 16 Aug 2018 13:14:34 -0700 Subject: [PATCH 018/540] Reorg code to escape bad indentation. --- tensorflow/python/keras/engine/topology_test.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index 25ae3a61c3..1fcd77d7f6 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -912,18 +912,13 @@ class TopologyConstructionTest(test.TestCase): assert out.shape == (4, 3, 2, 1) self.assertAllClose(out, x * 0.2 + x * 0.3, atol=1e-4) - def test_constant_initializer_with_numpy(self): with self.test_session(): + initializer = keras.initializers.Constant(np.ones((3, 2))) model = keras.models.Sequential() - model.add( - keras.layers.Dense( - 2, - input_shape = (3,), - kernel_initializer = keras.initializers.Constant(np.ones((3, 2))) - ) - ) + model.add(keras.layers.Dense(2, input_shape=(3,), + kernel_initializer=initializer)) model.add(keras.layers.Dense(3)) model.compile(loss='mse', optimizer='sgd', metrics=['acc']) -- GitLab From c4858c15110286b1afd091c70ab4d99549b2e856 Mon Sep 17 00:00:00 2001 From: Lukas Geiger Date: Sat, 18 Aug 2018 10:01:17 +0200 Subject: [PATCH 019/540] [tfgan] Respect use_loss_summaries in GANEstimator Since the refactor done in 47dea684efa41981e10299c2737317c504ce41af the `use_loss_summaries` argument of GANEstimator isn't respected anymore. This PR restores the original behavior and passes `use_loss_summaries` down to the loss functions. --- .../gan/python/estimator/python/gan_estimator_impl.py | 10 ++++++---- .../gan/python/estimator/python/gan_estimator_test.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 8e4affb9b4..3dd066a406 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -187,7 +187,7 @@ class GANEstimator(estimator.Estimator): return _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn) + get_hooks_fn, use_loss_summaries) super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) @@ -214,15 +214,17 @@ def _get_gan_model( def _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn=None): + get_hooks_fn=None, use_loss_summaries=True): """Get the EstimatorSpec for the current mode.""" if mode == model_fn_lib.ModeKeys.PREDICT: estimator_spec = model_fn_lib.EstimatorSpec( mode=mode, predictions=gan_model.generated_data) else: gan_loss = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn(gan_model), - discriminator_loss=discriminator_loss_fn(gan_model)) + generator_loss=generator_loss_fn( + gan_model, add_summaries=use_loss_summaries), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=use_loss_summaries)) if mode == model_fn_lib.ModeKeys.EVAL: estimator_spec = _get_eval_estimator_spec( gan_model, gan_loss, get_eval_metric_ops_fn) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 9ac9c6ca9c..83f8dd641f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -116,7 +116,7 @@ def get_dummy_gan_model(): discriminator_fn=None) -def dummy_loss_fn(gan_model): +def dummy_loss_fn(gan_model, add_summaries=True): return math_ops.reduce_sum(gan_model.discriminator_real_outputs - gan_model.discriminator_gen_outputs) -- GitLab From 74c3a77ab3eb91f1ca36c3728e15827246f4d089 Mon Sep 17 00:00:00 2001 From: Artem Sobolev Date: Sun, 19 Aug 2018 12:45:42 +0300 Subject: [PATCH 020/540] Use tf.platform FLAGS wrapper instead of raw absl --- tensorflow/python/ops/parallel_for/pfor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 2e4b2fd64e..6689c309c7 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -21,8 +21,6 @@ from __future__ import print_function import collections -from absl import flags - from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -41,6 +39,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import flags from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest -- GitLab From 0c8c6fc35f5939c9ae54e29c0051090f49cee274 Mon Sep 17 00:00:00 2001 From: Artem Sobolev Date: Sun, 19 Aug 2018 12:46:57 +0300 Subject: [PATCH 021/540] Make SoftplusGrad convertible --- tensorflow/python/ops/parallel_for/pfor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py index 6689c309c7..58fa6447f3 100644 --- a/tensorflow/python/ops/parallel_for/pfor.py +++ b/tensorflow/python/ops/parallel_for/pfor.py @@ -2010,6 +2010,7 @@ def _convert_biasaddgrad(pfor_input): @RegisterPForWithArgs("ReluGrad") @RegisterPForWithArgs("TanhGrad") @RegisterPForWithArgs("SigmoidGrad") +@RegisterPForWithArgs("SoftplusGrad") def _convert_grads(pfor_input, op_type, *args, **kw_args): del args del kw_args -- GitLab From 8c4737fa73d74e0c445a1ac90a4f08e4196f0e34 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 19 Aug 2018 23:12:22 +0000 Subject: [PATCH 022/540] Fix documentation issue with `tf.nn.conv1d` The `tf.nn.conv1d` supports float16, float32, and float64 though in `tf.nn.conv1d.__doc__` only float16 and float32 are mentioned. This fix updates the doc string to add float64 as the supported data type. Signed-off-by: Yong Tang --- tensorflow/python/ops/nn_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index edc6e04b48..b6e8174ace 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -2454,7 +2454,7 @@ def conv1d(value, returned to the caller. Args: - value: A 3D `Tensor`. Must be of type `float16` or `float32`. + value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`. filters: A 3D `Tensor`. Must have the same type as `value`. stride: An `integer`. The number of entries by which the filter is moved right at each step. -- GitLab From 94ef0a70717d83316042cba924e70fe024a51661 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Tue, 21 Aug 2018 21:38:30 +0200 Subject: [PATCH 023/540] Fixed mode in load_inputs_from_input_arg_string NPY files are binary and should be opened with mode "rb". --- tensorflow/python/tools/saved_model_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 38fed5335e..f215ac80ae 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -544,7 +544,7 @@ def load_inputs_from_input_arg_string(inputs_str, input_exprs_str, input_examples = preprocess_input_examples_arg_string(input_examples_str) for input_tensor_key, (filename, variable_name) in inputs.items(): - data = np.load(file_io.FileIO(filename, mode='r')) + data = np.load(file_io.FileIO(filename, mode='rb')) # When a variable_name key is specified for the input file if variable_name: -- GitLab From 4c2f6aeaaf4aeafccc85a289a5a105d52738b410 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 17 Aug 2018 17:06:47 -0400 Subject: [PATCH 024/540] Simplyfing the evaluation step by taking argmax of the softmax of the predictions instead of tf.multinomial --- .../generative_examples/image_captioning_with_attention.ipynb | 2 +- .../python/examples/generative_examples/text_generation.ipynb | 2 +- .../python/examples/nmt_with_attention/nmt_with_attention.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index 315d7a4893..e0f7137184 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1056,7 +1056,7 @@ "\n", " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", " result.append(index_word[predicted_id])\n", "\n", " if index_word[predicted_id] == '':\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index 40bc098724..b13e5aae9b 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -610,7 +610,7 @@ "\n", " # using a multinomial distribution to predict the word returned by the model\n", " predictions = predictions / temperature\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden state\n", diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index f1e1f99c57..3e02d9fbb0 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -677,7 +677,7 @@ " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", - " predicted_id = tf.multinomial(predictions, num_samples=1)[0][0].numpy()\n", + " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", -- GitLab From c36ff7ae1d667979fa49899bf97de26cf35321de Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 17 Aug 2018 20:44:14 -0400 Subject: [PATCH 025/540] Removing tf.nn.softmax --- .../generative_examples/image_captioning_with_attention.ipynb | 2 +- .../python/examples/generative_examples/text_generation.ipynb | 2 +- .../python/examples/nmt_with_attention/nmt_with_attention.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb index e0f7137184..5c753ec0f5 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb @@ -1056,7 +1056,7 @@ "\n", " attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()\n", "\n", - " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " result.append(index_word[predicted_id])\n", "\n", " if index_word[predicted_id] == '':\n", diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb index b13e5aae9b..e0d5e494d4 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/text_generation.ipynb @@ -610,7 +610,7 @@ "\n", " # using a multinomial distribution to predict the word returned by the model\n", " predictions = predictions / temperature\n", - " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", " \n", " # We pass the predicted word as the next input to the model\n", " # along with the previous hidden state\n", diff --git a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb index 3e02d9fbb0..560fc8c5a2 100644 --- a/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb +++ b/tensorflow/contrib/eager/python/examples/nmt_with_attention/nmt_with_attention.ipynb @@ -677,7 +677,7 @@ " attention_weights = tf.reshape(attention_weights, (-1, ))\n", " attention_plot[t] = attention_weights.numpy()\n", "\n", - " predicted_id = tf.argmax(tf.nn.softmax(predictions[0])).numpy()\n", + " predicted_id = tf.argmax(predictions[0]).numpy()\n", "\n", " result += targ_lang.idx2word[predicted_id] + ' '\n", "\n", -- GitLab From e357bcea4b10d5e5cbc3a4ba59385e832401ba8d Mon Sep 17 00:00:00 2001 From: Dao Zhang Date: Thu, 23 Aug 2018 20:11:10 +0800 Subject: [PATCH 026/540] merge_repeated option is confusing I have the same question with [WIP: Remove invalid merge_repeated option from CTC beam decoder](https://github.com/tensorflow/tensorflow/pull/15586), it's a pity I haven't seen any changes for so long. Generally I will use the default value of merge_repeated: True, but I found it's confusing, that is, I got the wrong anser, it has been explained well in [WIP: Remove invalid merge_repeated option from CTC beam decoder](https://github.com/tensorflow/tensorflow/pull/15586). And the top path in ctc_beam_search_decoder is similar with sequence in ctc_greedy_decoder, this is confusing, I have found the project [CRNN](https://github.com/Belval/CRNN/blob/master/CRNN/crnn.py)(line 167) and some other projects use the wrong settings. So I think it's better to give a explain here, this has no conflict with the existing code. --- tensorflow/python/ops/ctc_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py index 908e793902..6bfe405b2b 100644 --- a/tensorflow/python/ops/ctc_ops.py +++ b/tensorflow/python/ops/ctc_ops.py @@ -242,11 +242,11 @@ def ctc_beam_search_decoder(inputs, sequence_length, beam_width=100, If `merge_repeated` is `True`, merge repeated classes in the output beams. This means that if consecutive entries in a beam are the same, - only the first of these is emitted. That is, when the top path - is `A B B B B`, the return value is: + only the first of these is emitted. That is, when the sequence is `A B B * B * B` (where '*' + is the blank label), the return value is: * `A B` if `merge_repeated = True`. - * `A B B B B` if `merge_repeated = False`. + * `A B B B` if `merge_repeated = False`. Args: inputs: 3-D `float` `Tensor`, size -- GitLab From 512f95d4b5e350fa0709aeef975730f22112b970 Mon Sep 17 00:00:00 2001 From: Clayne Robison Date: Fri, 24 Aug 2018 11:34:10 -0700 Subject: [PATCH 027/540] [Intel MKL] Adding cc tests to the MKL public CI tests. --- tensorflow/tools/ci_build/linux/cpu/run_mkl.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh index 2a9f295188..7be5f454ec 100755 --- a/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh +++ b/tensorflow/tools/ci_build/linux/cpu/run_mkl.sh @@ -33,7 +33,7 @@ yes "" | $PYTHON_BIN_PATH configure.py # Setting KMP_BLOCKTIME to 0 lets OpenMP threads to sleep right after parallel execution # in an MKL primitive. This reduces the effects of an oversubscription of OpenMP threads # caused by executing multiple tests concurrently. -bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=py -k \ +bazel test --test_tag_filters=-no_oss,-oss_serial,-gpu,-benchmark-test --test_lang_filters=cc,py -k \ --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 --build_tests_only \ --config=mkl --test_env=KMP_BLOCKTIME=0 --config=opt --test_output=errors -- \ //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... -- GitLab From a7deb79f258a5dded26fcf85e9416f8463def451 Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Wed, 11 Jul 2018 11:24:58 +0800 Subject: [PATCH 028/540] [XLA/AOT] Build LLVM with Bazel on Windows --- third_party/llvm/llvm.bzl | 170 +++++++++++++++++++++++++++++++------- 1 file changed, 141 insertions(+), 29 deletions(-) diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index d493a3c476..626e0db3b1 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -150,6 +150,35 @@ def expand_cmake_vars(name, src, dst, cmake_vars): # The set of CMake variables common to all targets. cmake_vars = { + # LLVM features + "ENABLE_BACKTRACES": 1, + "LLVM_BINDIR": "/dev/null", + "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, + "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, + "LLVM_ENABLE_THREADS": 1, + "LLVM_ENABLE_ZLIB": 1, + "LLVM_HAS_ATOMICS": 1, + "LLVM_INCLUDEDIR": "/dev/null", + "LLVM_INFODIR": "/dev/null", + "LLVM_MANDIR": "/dev/null", + "LLVM_NATIVE_TARGET": 1, + "LLVM_NATIVE_TARGETINFO": 1, + "LLVM_NATIVE_TARGETMC": 1, + "LLVM_NATIVE_ASMPRINTER": 1, + "LLVM_NATIVE_ASMPARSER": 1, + "LLVM_NATIVE_DISASSEMBLER": 1, + "LLVM_PREFIX": "/dev/null", + "LLVM_VERSION_MAJOR": 0, + "LLVM_VERSION_MINOR": 0, + "LLVM_VERSION_PATCH": 0, + "PACKAGE_NAME": "llvm", + "PACKAGE_STRING": "llvm tensorflow-trunk", + "PACKAGE_VERSION": "tensorflow-trunk", + "RETSIGTYPE": "void", +} + +# The set of CMake variables common to POSIX targets. +posix_cmake_vars = { # Headers "HAVE_DIRENT_H": 1, "HAVE_DLFCN_H": 1, @@ -206,32 +235,8 @@ cmake_vars = { "HAVE__UNWIND_BACKTRACE": 1, # LLVM features - "ENABLE_BACKTRACES": 1, - "LLVM_BINDIR": "/dev/null", - "LLVM_DISABLE_ABI_BREAKING_CHECKS_ENFORCING": 0, - "LLVM_ENABLE_ABI_BREAKING_CHECKS": 0, - "LLVM_ENABLE_THREADS": 1, - "LLVM_ENABLE_ZLIB": 1, - "LLVM_HAS_ATOMICS": 1, - "LLVM_INCLUDEDIR": "/dev/null", - "LLVM_INFODIR": "/dev/null", - "LLVM_MANDIR": "/dev/null", - "LLVM_NATIVE_TARGET": 1, - "LLVM_NATIVE_TARGETINFO": 1, - "LLVM_NATIVE_TARGETMC": 1, - "LLVM_NATIVE_ASMPRINTER": 1, - "LLVM_NATIVE_ASMPARSER": 1, - "LLVM_NATIVE_DISASSEMBLER": 1, "LLVM_ON_UNIX": 1, - "LLVM_PREFIX": "/dev/null", - "LLVM_VERSION_MAJOR": 0, - "LLVM_VERSION_MINOR": 0, - "LLVM_VERSION_PATCH": 0, "LTDL_SHLIB_EXT": ".so", - "PACKAGE_NAME": "llvm", - "PACKAGE_STRING": "llvm tensorflow-trunk", - "PACKAGE_VERSION": "tensorflow-trunk", - "RETSIGTYPE": "void", } # CMake variables specific to the Linux platform @@ -247,6 +252,40 @@ darwin_cmake_vars = { "HAVE_MALLOC_MALLOC_H": 1, } +# CMake variables specific to the Windows platform. +win32_cmake_vars = { + # Headers + "HAVE_ERRNO_H": 1, + "HAVE_EXECINFO_H": 1, + "HAVE_FCNTL_H": 1, + "HAVE_FENV_H": 1, + "HAVE_INTTYPES_H": 1, + "HAVE_MALLOC_H": 1, + "HAVE_SIGNAL_H": 1, + "HAVE_STDINT_H": 1, + "HAVE_SYS_STAT_H": 1, + "HAVE_SYS_TYPES_H": 1, + "HAVE_ZLIB_H": 1, + + # Features + "BACKTRACE_HEADER": "execinfo.h", + "HAVE_GETCWD": 1, + "HAVE_INT64_T": 1, + "HAVE_STRERROR": 1, + "HAVE_STRTOLL": 1, + "HAVE_SYSCONF": 1, + "HAVE_UINT64_T": 1, + "HAVE__CHSIZE_S": 1, + "HAVE___CHKSTK": 1, + + # MSVC specific + "stricmp": "_stricmp", + "strdup": "_strdup", + + # LLVM features + "LTDL_SHLIB_EXT": ".dll", +} + # Select a set of CMake variables based on the platform. # TODO(phawkins): use a better method to select the right host triple, rather # than hardcoding x86_64. @@ -265,6 +304,13 @@ llvm_all_cmake_vars = select({ linux_cmake_vars, ), ), + "@org_tensorflow//tensorflow:windows": cmake_var_string( + _dict_add( + cmake_vars, + llvm_target_cmake_vars("X86", "x86_64-pc-win32"), + win32_cmake_vars, + ), + ), "//conditions:default": cmake_var_string( _dict_add( cmake_vars, @@ -274,23 +320,89 @@ llvm_all_cmake_vars = select({ ), }) -llvm_linkopts = ["-ldl", "-lm", "-lpthread"] +llvm_linkopts = select({ + "@org_tensorflow//tensorflow:windows": [], + "//conditions:default": ["-ldl", "-lm", "-lpthread"], +}) -llvm_defines = [ - "LLVM_ENABLE_STATS", +llvm_defines = select({ + "@org_tensorflow//tensorflow:windows": [ + "_CRT_SECURE_NO_DEPRECATE", + "_CRT_SECURE_NO_WARNINGS", + "_CRT_NONSTDC_NO_DEPRECATE", + "_CRT_NONSTDC_NO_WARNINGS", + "_SCL_SECURE_NO_DEPRECATE", + "_SCL_SECURE_NO_WARNINGS", + "UNICODE", + "_UNICODE", + ], + "//conditions:default": ["_DEBUG"], +}) + [ "__STDC_LIMIT_MACROS", "__STDC_CONSTANT_MACROS", "__STDC_FORMAT_MACROS", - "_DEBUG", "LLVM_BUILD_GLOBAL_ISEL", ] -llvm_copts = [] +llvm_copts = select({ + "@org_tensorflow//tensorflow:windows": [ + "-Zc:inline", + "-Zc:strictStrings", + "-Zc:rvalueCast", + "-Oi", + "-wd4141", + "-wd4146", + "-wd4180", + "-wd4244", + "-wd4258", + "-wd4267", + "-wd4291", + "-wd4345", + "-wd4351", + "-wd4355", + "-wd4456", + "-wd4457", + "-wd4458", + "-wd4459", + "-wd4503", + "-wd4624", + "-wd4722", + "-wd4800", + "-wd4100", + "-wd4127", + "-wd4512", + "-wd4505", + "-wd4610", + "-wd4510", + "-wd4702", + "-wd4245", + "-wd4706", + "-wd4310", + "-wd4701", + "-wd4703", + "-wd4389", + "-wd4611", + "-wd4805", + "-wd4204", + "-wd4577", + "-wd4091", + "-wd4592", + "-wd4319", + "-wd4324", + "-w14062", + "-we4238", + ], + "//conditions:default": [], +}) # Platform specific sources for libSupport. def llvm_support_platform_specific_srcs_glob(): return select({ + "@org_tensorflow//tensorflow:windows": native.glob([ + "lib/Support/Windows/*.inc", + "lib/Support/Windows/*.h" + ]), "//conditions:default": native.glob([ "lib/Support/Unix/*.inc", "lib/Support/Unix/*.h", -- GitLab From 4a4ce8c6bff872f2a5522b289845491ea2da6f1e Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Wed, 11 Jul 2018 11:32:54 +0800 Subject: [PATCH 029/540] Add back LLVM_ENABLE_STATS --- third_party/llvm/llvm.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 626e0db3b1..6da3e0755c 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -338,6 +338,7 @@ llvm_defines = select({ ], "//conditions:default": ["_DEBUG"], }) + [ + "LLVM_ENABLE_STATS", "__STDC_LIMIT_MACROS", "__STDC_CONSTANT_MACROS", "__STDC_FORMAT_MACROS", -- GitLab From d0b4230bc3052f080c901f7d999cf848c7d81450 Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Sat, 11 Aug 2018 18:11:47 +0800 Subject: [PATCH 030/540] Actually add posix_cmake_vars --- third_party/llvm/llvm.bzl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/llvm/llvm.bzl b/third_party/llvm/llvm.bzl index 6da3e0755c..586935b6e6 100644 --- a/third_party/llvm/llvm.bzl +++ b/third_party/llvm/llvm.bzl @@ -294,6 +294,7 @@ llvm_all_cmake_vars = select({ _dict_add( cmake_vars, llvm_target_cmake_vars("X86", "x86_64-apple-darwin"), + posix_cmake_vars, darwin_cmake_vars, ), ), @@ -301,6 +302,7 @@ llvm_all_cmake_vars = select({ _dict_add( cmake_vars, llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu"), + posix_cmake_vars, linux_cmake_vars, ), ), @@ -315,6 +317,7 @@ llvm_all_cmake_vars = select({ _dict_add( cmake_vars, llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu"), + posix_cmake_vars, linux_cmake_vars, ), ), -- GitLab From b4fe246e9680192532a949292ef10e95c0f8b98c Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 26 Aug 2018 19:08:56 +0000 Subject: [PATCH 031/540] Fix incorrect link in `dockerfiles/README.md` This fix fixes incorrect link in `dockerfiles/README.md`. --- tensorflow/tools/dockerfiles/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/tools/dockerfiles/README.md b/tensorflow/tools/dockerfiles/README.md index c484c162cb..d64db35afb 100644 --- a/tensorflow/tools/dockerfiles/README.md +++ b/tensorflow/tools/dockerfiles/README.md @@ -2,8 +2,8 @@ This directory houses TensorFlow's Dockerfiles. **DO NOT EDIT THE DOCKERFILES MANUALLY!** They are maintained by `assembler.py`, which builds Dockerfiles from -the files in `partials/` and the rules in `spec.yml`. See [the Maintaining -section](#maintaining) for more information. +the files in `partials/` and the rules in `spec.yml`. See [the Contributing +section](#contributing) for more information. ## Building -- GitLab From 476f65230982842fdd7fabe2ed8d80ee719c20dc Mon Sep 17 00:00:00 2001 From: "William D. Irons" Date: Mon, 27 Aug 2018 13:29:52 -0500 Subject: [PATCH 032/540] Disable GPU test for scatter_add_ndim_op_test As scatter_add_ndim doesn't have implementation for GPU, the test needs to be excluded from GPU test to prevent it from failing. Currently fails on both x86_64 and ppc64le. Fixes #21833 --- tensorflow/contrib/tensor_forest/BUILD | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index cf55fec488..4008699dda 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -462,7 +462,10 @@ py_test( size = "small", srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip_gpu"], + tags = [ + "no_pip_gpu", + "no_gpu", + ], deps = [ ":tensor_forest_ops_py", "//tensorflow/python:framework_test_lib", -- GitLab From b146281fd7f11325251fb085aca6bda8e2d77bfd Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Mon, 27 Aug 2018 11:33:21 -0700 Subject: [PATCH 033/540] [Intel MKL] Using default CPU allocator for small allocations in MklCPUAllocator This PR adds support to use default CPU allocator for handling small-size allocations. We found that BFC allocator does not do well on small allocations, but is good for large allocations. --- .../core/common_runtime/mkl_cpu_allocator.h | 177 +++++++++++++++++- 1 file changed, 168 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 99bd43e090..2778213a82 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/framework/allocator_registry.h" +#include "tensorflow/core/platform/mutex.h" #ifndef INTEL_MKL_DNN_ONLY #include "i_malloc.h" @@ -48,6 +50,120 @@ class MklSubAllocator : public SubAllocator { void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); } }; +/// CPU allocator that handles small-size allocations by calling +/// suballocator directly. Mostly, it is just a wrapper around a suballocator +/// (that calls malloc and free directly) with support for bookkeeping. +class MklSmallSizeAllocator : public VisitableAllocator { + public: + MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, + const string& name) : sub_allocator_(sub_allocator), + name_(name) { + stats_.bytes_limit = total_memory; + } + ~MklSmallSizeAllocator() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator); + + inline string Name() override { return name_; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* ptr = nullptr; + if ((ptr = sub_allocator_->Alloc(alignment, num_bytes)) != nullptr) { + std::pair map_val(ptr, num_bytes); + mutex_lock l(mutex_); + // Check that insertion in the hash map was successful. + CHECK_EQ(map_.insert(map_val).second, true); + // Increment statistics for small-size allocations. + IncrementStats(num_bytes); + // Call alloc visitors. + for (const auto& visitor : alloc_visitors_) { + visitor(ptr, num_bytes); + } + } + return ptr; + } + + void DeallocateRaw(void* ptr) override { + if (ptr == nullptr) { + LOG(ERROR) << "tried to deallocate nullptr"; + return; + } + + mutex_lock l(mutex_); + auto map_iter = map_.find(ptr); + if (map_iter != map_.end()) { + // Call free visitors. + size_t dealloc_bytes = map_iter->second; + for (const auto& visitor : free_visitors_) { + visitor(ptr, dealloc_bytes); + } + sub_allocator_->Free(ptr, dealloc_bytes); + DecrementStats(dealloc_bytes); + map_.erase(map_iter); + } + } + + inline bool IsSmallSizeAllocation(const void* ptr) const { + mutex_lock l(mutex_); + return map_.find(ptr) != map_.end(); + } + + void GetStats(AllocatorStats* stats) override { + mutex_lock l(mutex_); + *stats = stats_; + } + + void ClearStats() override { + mutex_lock l(mutex_); + stats_.Clear(); + } + + void AddAllocVisitor(Visitor visitor) override { + mutex_lock l(mutex_); + alloc_visitors_.push_back(visitor); + } + + void AddFreeVisitor(Visitor visitor) override { + mutex_lock l(mutex_); + free_visitors_.push_back(visitor); + } + + private: + /// Increment statistics for the allocator handling small allocations. + inline void IncrementStats(size_t alloc_size) { + ++stats_.num_allocs; + stats_.bytes_in_use += alloc_size; + stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, + stats_.bytes_in_use); + stats_.max_alloc_size = std::max(alloc_size, + static_cast(stats_.max_alloc_size)); + } + + /// Decrement statistics for the allocator handling small allocations. + inline void DecrementStats(size_t dealloc_size) { + stats_.bytes_in_use -= dealloc_size; + } + + SubAllocator* sub_allocator_; // Not owned by this class. + + /// Mutex for protecting updates to map of allocations. + mutable mutex mutex_; + + /// Allocator name + string name_; + + /// Hash map to keep track of "small" allocations + /// We do not use BFC allocator for small allocations. + std::unordered_map map_ GUARDED_BY(mutex_); + + /// Allocator stats for small allocs + AllocatorStats stats_ GUARDED_BY(mutex_); + + /// Visitors + std::vector alloc_visitors_ GUARDED_BY(mutex_); + std::vector free_visitors_ GUARDED_BY(mutex_); +}; + /// CPU allocator for MKL that wraps BFC allocator and intercepts /// and redirects memory allocation calls from MKL. class MklCPUAllocator : public VisitableAllocator { @@ -62,7 +178,10 @@ class MklCPUAllocator : public VisitableAllocator { MklCPUAllocator() { TF_CHECK_OK(Initialize()); } - ~MklCPUAllocator() override { delete allocator_; } + ~MklCPUAllocator() override { + delete small_size_allocator_; + delete large_size_allocator_; + } Status Initialize() { VLOG(2) << "MklCPUAllocator: In MklCPUAllocator"; @@ -96,7 +215,11 @@ class MklCPUAllocator : public VisitableAllocator { } VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes; - allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes, + + sub_allocator_ = new MklSubAllocator(); + small_size_allocator_ = new MklSmallSizeAllocator(sub_allocator_, + max_mem_bytes, kName); + large_size_allocator_ = new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); #ifndef INTEL_MKL_DNN_ONLY // For redirecting all allocations from MKL to this allocator @@ -112,23 +235,52 @@ class MklCPUAllocator : public VisitableAllocator { inline string Name() override { return kName; } inline void* AllocateRaw(size_t alignment, size_t num_bytes) override { - return allocator_->AllocateRaw(alignment, num_bytes); + // If the allocation size is less than threshold, call small allocator, + // otherwise call large-size allocator (BFC). We found that BFC allocator + // does not deliver good performance for small allocations when + // inter_op_parallelism_threads is high. + return (num_bytes < kSmallAllocationsThreshold) ? + small_size_allocator_->AllocateRaw(alignment, num_bytes) : + large_size_allocator_->AllocateRaw(alignment, num_bytes); } inline void DeallocateRaw(void* ptr) override { - allocator_->DeallocateRaw(ptr); + // Check if ptr is for "small" allocation. If it is, then call Free + // directly. Otherwise, call BFC to handle free. + if (small_size_allocator_->IsSmallSizeAllocation(ptr)) { + small_size_allocator_->DeallocateRaw(ptr); + } else { + large_size_allocator_->DeallocateRaw(ptr); + } } - void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); } + void GetStats(AllocatorStats* stats) override { + AllocatorStats l_stats, s_stats; + small_size_allocator_->GetStats(&s_stats); + large_size_allocator_->GetStats(&l_stats); + + // Combine statistics from small-size and large-size allocator. + stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs; + stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use; + stats->max_bytes_in_use = l_stats.max_bytes_in_use + + s_stats.max_bytes_in_use; + stats->max_alloc_size = std::max(l_stats.max_alloc_size, + s_stats.max_alloc_size); + } - void ClearStats() override { allocator_->ClearStats(); } + void ClearStats() override { + small_size_allocator_->ClearStats(); + large_size_allocator_->ClearStats(); + } void AddAllocVisitor(Visitor visitor) override { - allocator_->AddAllocVisitor(visitor); + small_size_allocator_->AddAllocVisitor(visitor); + large_size_allocator_->AddAllocVisitor(visitor); } void AddFreeVisitor(Visitor visitor) override { - allocator_->AddFreeVisitor(visitor); + small_size_allocator_->AddFreeVisitor(visitor); + large_size_allocator_->AddFreeVisitor(visitor); } private: @@ -165,7 +317,14 @@ class MklCPUAllocator : public VisitableAllocator { /// The alignment that we need for the allocations static constexpr const size_t kAlignment = 64; - VisitableAllocator* allocator_; // owned by this class + VisitableAllocator* large_size_allocator_; // owned by this class + MklSmallSizeAllocator* small_size_allocator_; // owned by this class. + + SubAllocator* sub_allocator_; // not owned by this class + + /// Size in bytes that defines the upper-bound for "small" allocations. + /// Any allocation below this threshold is "small" allocation. + static constexpr const size_t kSmallAllocationsThreshold = 4096; }; } // namespace tensorflow -- GitLab From 713cc582954399763a078226b62953bba1450b91 Mon Sep 17 00:00:00 2001 From: Ming Li Date: Wed, 29 Aug 2018 15:02:49 +0100 Subject: [PATCH 034/540] minor typo in `make_callable` method. --- tensorflow/python/client/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 1841dd998b..c04d289773 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1132,7 +1132,7 @@ class BaseSession(SessionInterface): for details of the allowable fetch types. feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run` for details of the allowable feed key types. - accept_options: (Optional.) Iff `True`, the returned `Callable` will be + accept_options: (Optional.) If `True`, the returned `Callable` will be able to accept `tf.RunOptions` and `tf.RunMetadata` as optional keyword arguments `options` and `run_metadata`, respectively, with the same syntax and semantics as `tf.Session.run`, which is useful -- GitLab From 30d2046f016f948f5b572be2f2f4f649f34d576d Mon Sep 17 00:00:00 2001 From: Jason Zaman Date: Fri, 31 Aug 2018 15:39:06 +0800 Subject: [PATCH 035/540] third_party: update libjpeg-turbo to 2.0.0 libjpeg-turbo-2.0.0 fixes CVE-2018-1152 and CVE-2018-11813 The build and source tree has been rearranged, the simd files are now in subdirs. Signed-off-by: Jason Zaman --- tensorflow/workspace.bzl | 8 +- third_party/jpeg/jpeg.BUILD | 324 ++++++++++++++++++++++-------------- 2 files changed, 201 insertions(+), 131 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index e131c532cb..758c94c542 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -240,11 +240,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "jpeg", urls = [ - "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz", - "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.3.tar.gz", + "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz", + "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz", ], - sha256 = "1a17020f859cb12711175a67eab5c71fc1904e04b587046218e36106e07eabde", - strip_prefix = "libjpeg-turbo-1.5.3", + sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b", + strip_prefix = "libjpeg-turbo-2.0.0", build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"), system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"), ) diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index 96e7ac061c..946f13de12 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -144,27 +144,27 @@ cc_library( "jpeglib.h", "jsimd.h", "jsimddct.h", - "simd/jccolor-altivec.c", - "simd/jcgray-altivec.c", - "simd/jcsample.h", - "simd/jcsample-altivec.c", - "simd/jdcolor-altivec.c", - "simd/jdmerge-altivec.c", - "simd/jdsample-altivec.c", - "simd/jfdctfst-altivec.c", - "simd/jfdctint-altivec.c", - "simd/jidctfst-altivec.c", - "simd/jidctint-altivec.c", - "simd/jquanti-altivec.c", "simd/jsimd.h", - "simd/jsimd_altivec.h", - "simd/jsimd_powerpc.c", + "simd/powerpc/jccolor-altivec.c", + "simd/powerpc/jcgray-altivec.c", + "simd/powerpc/jcsample-altivec.c", + "simd/powerpc/jdcolor-altivec.c", + "simd/powerpc/jdmerge-altivec.c", + "simd/powerpc/jdsample-altivec.c", + "simd/powerpc/jfdctfst-altivec.c", + "simd/powerpc/jfdctint-altivec.c", + "simd/powerpc/jidctfst-altivec.c", + "simd/powerpc/jidctint-altivec.c", + "simd/powerpc/jquanti-altivec.c", + "simd/powerpc/jsimd.c", ], hdrs = [ - "simd/jccolext-altivec.c", # should have been named .inc - "simd/jcgryext-altivec.c", # should have been named .inc - "simd/jdcolext-altivec.c", # should have been named .inc - "simd/jdmrgext-altivec.c", # should have been named .inc + "simd/powerpc/jccolext-altivec.c", + "simd/powerpc/jcgryext-altivec.c", + "simd/powerpc/jdcolext-altivec.c", + "simd/powerpc/jdmrgext-altivec.c", + "simd/powerpc/jcsample.h", + "simd/powerpc/jsimd_altivec.h", ], copts = libjpegturbo_copts, nocopts = libjpegturbo_nocopts, @@ -175,6 +175,7 @@ cc_library( srcs = [ "jchuff.h", "jconfig.h", + "jconfigint.h", "jdct.h", "jerror.h", "jinclude.h", @@ -183,24 +184,35 @@ cc_library( "jpeglib.h", "jsimd.h", "jsimddct.h", - "simd/jccolor-sse2-64.o", - "simd/jcgray-sse2-64.o", - "simd/jchuff-sse2-64.o", - "simd/jcsample-sse2-64.o", - "simd/jdcolor-sse2-64.o", - "simd/jdmerge-sse2-64.o", - "simd/jdsample-sse2-64.o", - "simd/jfdctflt-sse-64.o", - "simd/jfdctfst-sse2-64.o", - "simd/jfdctint-sse2-64.o", - "simd/jidctflt-sse2-64.o", - "simd/jidctfst-sse2-64.o", - "simd/jidctint-sse2-64.o", - "simd/jidctred-sse2-64.o", - "simd/jquantf-sse2-64.o", - "simd/jquanti-sse2-64.o", "simd/jsimd.h", - "simd/jsimd_x86_64.c", + "simd/x86_64/jsimd.c", + "simd/x86_64/jccolor-avx2.o", + "simd/x86_64/jccolor-sse2.o", + "simd/x86_64/jcgray-avx2.o", + "simd/x86_64/jcgray-sse2.o", + "simd/x86_64/jchuff-sse2.o", + "simd/x86_64/jcphuff-sse2.o", + "simd/x86_64/jcsample-avx2.o", + "simd/x86_64/jcsample-sse2.o", + "simd/x86_64/jdcolor-avx2.o", + "simd/x86_64/jdcolor-sse2.o", + "simd/x86_64/jdmerge-avx2.o", + "simd/x86_64/jdmerge-sse2.o", + "simd/x86_64/jdsample-avx2.o", + "simd/x86_64/jdsample-sse2.o", + "simd/x86_64/jfdctflt-sse.o", + "simd/x86_64/jfdctfst-sse2.o", + "simd/x86_64/jfdctint-avx2.o", + "simd/x86_64/jfdctint-sse2.o", + "simd/x86_64/jidctflt-sse2.o", + "simd/x86_64/jidctfst-sse2.o", + "simd/x86_64/jidctint-avx2.o", + "simd/x86_64/jidctint-sse2.o", + "simd/x86_64/jidctred-sse2.o", + "simd/x86_64/jquantf-sse2.o", + "simd/x86_64/jquanti-avx2.o", + "simd/x86_64/jquanti-sse2.o", + "simd/x86_64/jsimdcpu.o", ], copts = libjpegturbo_copts, linkstatic = 1, @@ -210,57 +222,88 @@ cc_library( genrule( name = "simd_x86_64_assemblage23", srcs = [ - "simd/jccolext-sse2-64.asm", - "simd/jccolor-sse2-64.asm", - "simd/jcgray-sse2-64.asm", - "simd/jcgryext-sse2-64.asm", - "simd/jchuff-sse2-64.asm", - "simd/jcolsamp.inc", - "simd/jcsample-sse2-64.asm", - "simd/jdcolext-sse2-64.asm", - "simd/jdcolor-sse2-64.asm", - "simd/jdct.inc", - "simd/jdmerge-sse2-64.asm", - "simd/jdmrgext-sse2-64.asm", - "simd/jdsample-sse2-64.asm", - "simd/jfdctflt-sse-64.asm", - "simd/jfdctfst-sse2-64.asm", - "simd/jfdctint-sse2-64.asm", - "simd/jidctflt-sse2-64.asm", - "simd/jidctfst-sse2-64.asm", - "simd/jidctint-sse2-64.asm", - "simd/jidctred-sse2-64.asm", - "simd/jpeg_nbits_table.inc", - "simd/jquantf-sse2-64.asm", - "simd/jquanti-sse2-64.asm", - "simd/jsimdcfg.inc", - "simd/jsimdext.inc", + "jconfig.h", + "jconfigint.h", + "simd/x86_64/jccolext-avx2.asm", + "simd/x86_64/jccolext-sse2.asm", + "simd/x86_64/jccolor-avx2.asm", + "simd/x86_64/jccolor-sse2.asm", + "simd/x86_64/jcgray-avx2.asm", + "simd/x86_64/jcgray-sse2.asm", + "simd/x86_64/jcgryext-avx2.asm", + "simd/x86_64/jcgryext-sse2.asm", + "simd/x86_64/jchuff-sse2.asm", + "simd/x86_64/jcphuff-sse2.asm", + "simd/x86_64/jcsample-avx2.asm", + "simd/x86_64/jcsample-sse2.asm", + "simd/x86_64/jdcolext-avx2.asm", + "simd/x86_64/jdcolext-sse2.asm", + "simd/x86_64/jdcolor-avx2.asm", + "simd/x86_64/jdcolor-sse2.asm", + "simd/x86_64/jdmerge-avx2.asm", + "simd/x86_64/jdmerge-sse2.asm", + "simd/x86_64/jdmrgext-avx2.asm", + "simd/x86_64/jdmrgext-sse2.asm", + "simd/x86_64/jdsample-avx2.asm", + "simd/x86_64/jdsample-sse2.asm", + "simd/x86_64/jfdctflt-sse.asm", + "simd/x86_64/jfdctfst-sse2.asm", + "simd/x86_64/jfdctint-avx2.asm", + "simd/x86_64/jfdctint-sse2.asm", + "simd/x86_64/jidctflt-sse2.asm", + "simd/x86_64/jidctfst-sse2.asm", + "simd/x86_64/jidctint-avx2.asm", + "simd/x86_64/jidctint-sse2.asm", + "simd/x86_64/jidctred-sse2.asm", + "simd/x86_64/jquantf-sse2.asm", + "simd/x86_64/jquanti-avx2.asm", + "simd/x86_64/jquanti-sse2.asm", + "simd/x86_64/jsimdcpu.asm", + "simd/nasm/jcolsamp.inc", + "simd/nasm/jdct.inc", + "simd/nasm/jpeg_nbits_table.inc", + "simd/nasm/jsimdcfg.inc", + "simd/nasm/jsimdcfg.inc.h", + "simd/nasm/jsimdext.inc", ], outs = [ - "simd/jccolor-sse2-64.o", - "simd/jcgray-sse2-64.o", - "simd/jchuff-sse2-64.o", - "simd/jcsample-sse2-64.o", - "simd/jdcolor-sse2-64.o", - "simd/jdmerge-sse2-64.o", - "simd/jdsample-sse2-64.o", - "simd/jfdctflt-sse-64.o", - "simd/jfdctfst-sse2-64.o", - "simd/jfdctint-sse2-64.o", - "simd/jidctflt-sse2-64.o", - "simd/jidctfst-sse2-64.o", - "simd/jidctint-sse2-64.o", - "simd/jidctred-sse2-64.o", - "simd/jquantf-sse2-64.o", - "simd/jquanti-sse2-64.o", + "simd/x86_64/jccolor-avx2.o", + "simd/x86_64/jccolor-sse2.o", + "simd/x86_64/jcgray-avx2.o", + "simd/x86_64/jcgray-sse2.o", + "simd/x86_64/jchuff-sse2.o", + "simd/x86_64/jcphuff-sse2.o", + "simd/x86_64/jcsample-avx2.o", + "simd/x86_64/jcsample-sse2.o", + "simd/x86_64/jdcolor-avx2.o", + "simd/x86_64/jdcolor-sse2.o", + "simd/x86_64/jdmerge-avx2.o", + "simd/x86_64/jdmerge-sse2.o", + "simd/x86_64/jdsample-avx2.o", + "simd/x86_64/jdsample-sse2.o", + "simd/x86_64/jfdctflt-sse.o", + "simd/x86_64/jfdctfst-sse2.o", + "simd/x86_64/jfdctint-avx2.o", + "simd/x86_64/jfdctint-sse2.o", + "simd/x86_64/jidctflt-sse2.o", + "simd/x86_64/jidctfst-sse2.o", + "simd/x86_64/jidctint-avx2.o", + "simd/x86_64/jidctint-sse2.o", + "simd/x86_64/jidctred-sse2.o", + "simd/x86_64/jquantf-sse2.o", + "simd/x86_64/jquanti-avx2.o", + "simd/x86_64/jquanti-sse2.o", + "simd/x86_64/jsimdcpu.o", ], cmd = "for out in $(OUTS); do\n" + " $(location @nasm//:nasm) -f elf64" + - " -DELF -DPIC -DRGBX_FILLER_0XFF -D__x86_64__ -DARCH_X86_64" + - " -I $$(dirname $(location simd/jdct.inc))/" + - " -I $$(dirname $(location simd/jsimdcfg.inc))/" + + " -DELF -DPIC -D__x86_64__" + + " -I $$(dirname $(location jconfig.h))/" + + " -I $$(dirname $(location jconfigint.h))/" + + " -I $$(dirname $(location simd/nasm/jsimdcfg.inc.h))/" + + " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" + " -o $$out" + - " $$(dirname $(location simd/jdct.inc))/$$(basename $${out%.o}.asm)\n" + + " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.o}.asm)\n" + "done", tools = ["@nasm"], ) @@ -279,8 +322,8 @@ cc_library( "jsimd.h", "jsimddct.h", "simd/jsimd.h", - "simd/jsimd_arm.c", - "simd/jsimd_arm_neon.S", + "simd/arm/jsimd.c", + "simd/arm/jsimd_neon.S", ], copts = libjpegturbo_copts, nocopts = libjpegturbo_nocopts, @@ -300,8 +343,8 @@ cc_library( "jsimd.h", "jsimddct.h", "simd/jsimd.h", - "simd/jsimd_arm64.c", - "simd/jsimd_arm64_neon.S", + "simd/arm64/jsimd.c", + "simd/arm64/jsimd_neon.S", ], copts = libjpegturbo_copts, nocopts = libjpegturbo_nocopts, @@ -332,50 +375,44 @@ template_rule( out = "jconfig_win.h", substitutions = { "@JPEG_LIB_VERSION@": "62", - "@VERSION@": "1.5.1", - "@LIBJPEG_TURBO_VERSION_NUMBER@": "1005001", - "cmakedefine": "define", + "@VERSION@": "2.0.0", + "@LIBJPEG_TURBO_VERSION_NUMBER@": "2000000", "@BITS_IN_JSAMPLE@": "8", - }, -) - -template_rule( - name = "jconfigint_win", - src = "win/jconfigint.h.in", - out = "jconfigint_win.h", - substitutions = { - "@VERSION@": "1.5.1", - "@BUILD@": "20161115", - "@CMAKE_PROJECT_NAME@": "libjpeg-turbo", + "#cmakedefine C_ARITH_CODING_SUPPORTED": "#define C_ARITH_CODING_SUPPORTED", + "#cmakedefine D_ARITH_CODING_SUPPORTED": "#define D_ARITH_CODING_SUPPORTED", + "#cmakedefine MEM_SRCDST_SUPPORTED": "#define MEM_SRCDST_SUPPORTED", + "#cmakedefine WITH_SIMD": "", }, ) JCONFIG_NOWIN_COMMON_SUBSTITUTIONS = { - "LIBJPEG_TURBO_VERSION 0": "LIBJPEG_TURBO_VERSION 1.5.1", - "LIBJPEG_TURBO_VERSION_NUMBER 0": "LIBJPEG_TURBO_VERSION_NUMBER 1005001", - "#undef C_ARITH_CODING_SUPPORTED": "#define C_ARITH_CODING_SUPPORTED 1", - "#undef D_ARITH_CODING_SUPPORTED": "#define D_ARITH_CODING_SUPPORTED 1", - "#undef HAVE_LOCALE_H": "#define HAVE_LOCALE_H 1", - "#undef HAVE_STDDEF_H": "#define HAVE_STDDEF_H 1", - "#undef HAVE_STDLIB_H": "#define HAVE_STDLIB_H 1", - "#undef HAVE_UNSIGNED_CHAR": "#define HAVE_UNSIGNED_CHAR 1", - "#undef HAVE_UNSIGNED_SHORT": "#define HAVE_UNSIGNED_SHORT 1", - "#undef INCOMPLETE_TYPES_BROKEN": "", - "#undef MEM_SRCDST_SUPPORTED": "#define MEM_SRCDST_SUPPORTED 1", - "#undef NEED_BSD_STRINGS": "", - "#undef NEED_SYS_TYPES_H": "#define NEED_SYS_TYPES_H 1", - "#undef __CHAR_UNSIGNED__": "", + "@JPEG_LIB_VERSION@": "62", + "@VERSION@": "2.0.0", + "@LIBJPEG_TURBO_VERSION_NUMBER@": "2000000", + "#cmakedefine C_ARITH_CODING_SUPPORTED": "#define C_ARITH_CODING_SUPPORTED", + "#cmakedefine D_ARITH_CODING_SUPPORTED": "#define D_ARITH_CODING_SUPPORTED", + "#cmakedefine MEM_SRCDST_SUPPORTED": "#define MEM_SRCDST_SUPPORTED", + "@BITS_IN_JSAMPLE@": "8", + "#cmakedefine HAVE_LOCALE_H": "#define HAVE_LOCALE_H 1", + "#cmakedefine HAVE_STDDEF_H": "#define HAVE_STDDEF_H 1", + "#cmakedefine HAVE_STDLIB_H": "#define HAVE_STDLIB_H 1", + "#cmakedefine NEED_SYS_TYPES_H": "#define NEED_SYS_TYPES_H", + "#cmakedefine NEED_BSD_STRINGS": "", + "#cmakedefine HAVE_UNSIGNED_CHAR": "#define HAVE_UNSIGNED_CHAR 1", + "#cmakedefine HAVE_UNSIGNED_SHORT": "#define HAVE_UNSIGNED_SHORT 1", + "#cmakedefine INCOMPLETE_TYPES_BROKEN": "", + "#cmakedefine RIGHT_SHIFT_IS_UNSIGNED": "", + "#cmakedefine __CHAR_UNSIGNED__": "", "#undef const": "", "#undef size_t": "", - "#undef RIGHT_SHIFT_IS_UNSIGNED": "", } JCONFIG_NOWIN_SIMD_SUBSTITUTIONS = { - "#undef WITH_SIMD": "#define WITH_SIMD 1", + "#cmakedefine WITH_SIMD": "#define WITH_SIMD", } JCONFIG_NOWIN_NOSIMD_SUBSTITUTIONS = { - "#undef WITH_SIMD": "", + "#cmakedefine WITH_SIMD": "", } JCONFIG_NOWIN_SIMD_SUBSTITUTIONS.update(JCONFIG_NOWIN_COMMON_SUBSTITUTIONS) @@ -396,22 +433,55 @@ template_rule( substitutions = JCONFIG_NOWIN_SIMD_SUBSTITUTIONS, ) +JCONFIGINT_COMMON_SUBSTITUTIONS = { + "@BUILD@": "20180831", + "@VERSION@": "2.0.0", + "@CMAKE_PROJECT_NAME@": "libjpeg-turbo", + "#undef inline": "", + "#cmakedefine HAVE_INTRIN_H": "", +} + +JCONFIGINT_NOWIN_SUBSTITUTIONS = { + "#cmakedefine HAVE_BUILTIN_CTZL": "#define HAVE_BUILTIN_CTZL", + "@INLINE@" : "inline __attribute__((always_inline))", + "#define SIZEOF_SIZE_T @SIZE_T@": "#if (__WORDSIZE==64 && !defined(__native_client__))\n" + + "#define SIZEOF_SIZE_T 8\n" + + "#else\n" + + "#define SIZEOF_SIZE_T 4\n" + + "#endif\n", +} + +JCONFIGINT_WIN_SUBSTITUTIONS = { + "#cmakedefine HAVE_BUILTIN_CTZL": "", + "#define INLINE @INLINE@" : "#if defined(__GNUC__)\n" + + "#define INLINE inline __attribute__((always_inline))\n" + + "#elif defined(_MSC_VER)\n" + + "#define INLINE __forceinline\n" + + "#else\n" + + "#define INLINE\n" + + "#endif\n", + "#define SIZEOF_SIZE_T @SIZE_T@": "#if (__WORDSIZE==64)\n" + + "#define SIZEOF_SIZE_T 8\n" + + "#else\n" + + "#define SIZEOF_SIZE_T 4\n" + + "#endif\n", +} + +JCONFIGINT_NOWIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS) +JCONFIGINT_WIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS) + template_rule( name = "jconfigint_nowin", src = "jconfigint.h.in", out = "jconfigint_nowin.h", - substitutions = { - "#undef BUILD": "#define BUILD \"20161115\"", - "#undef inline": "", - "#undef INLINE": "#define INLINE inline __attribute__((always_inline))", - "#undef PACKAGE_NAME": "#define PACKAGE_NAME \"libjpeg-turbo\"", - "#undef VERSION": "#define VERSION \"1.5.1\"", - "#undef SIZEOF_SIZE_T": "#if (__WORDSIZE==64 && !defined(__native_client__))\n" + - "#define SIZEOF_SIZE_T 8\n" + - "#else\n" + - "#define SIZEOF_SIZE_T 4\n" + - "#endif\n", - }, + substitutions = JCONFIGINT_NOWIN_SUBSTITUTIONS, +) + +template_rule( + name = "jconfigint_win", + src = "jconfigint.h.in", + out = "jconfigint_win.h", + substitutions = JCONFIGINT_WIN_SUBSTITUTIONS, ) genrule( -- GitLab From cf7373be08a5d745b52d95f2d62e2ccc919ad748 Mon Sep 17 00:00:00 2001 From: coder3101 Date: Sat, 1 Sep 2018 00:36:45 +0530 Subject: [PATCH 036/540] Fixes the formatting issue pointed out at #21762 --- tensorflow/python/ops/rnn_cell_impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index c128a1039a..8a2da5f193 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -421,7 +421,7 @@ class BasicRNNCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % (inputs_shape,)) input_depth = inputs_shape[-1] self._kernel = self.add_variable( @@ -510,7 +510,7 @@ class GRUCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % (inputs_shape,)) input_depth = inputs_shape[-1] self._gate_kernel = self.add_variable( @@ -681,7 +681,7 @@ class BasicLSTMCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % (inputs_shape,)) input_depth = inputs_shape[-1] h_depth = self._num_units @@ -875,7 +875,7 @@ class LSTMCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % inputs_shape) + % (inputs_shape,)) input_depth = inputs_shape[-1] h_depth = self._num_units if self._num_proj is None else self._num_proj -- GitLab From c67ded664a20f27b4e90020bf76a097b462182b1 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 1 Sep 2018 06:26:02 +0000 Subject: [PATCH 037/540] Fix tensorflow master build failure with verbs This fix tries to address the issue in 21999 where tensorflow master build failed with verbs. The issue was caused by StringPiece replaced with `absl::string_view` This fix fixes 21999 Signed-off-by: Yong Tang --- tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc index ad3dce1784..d4951b156c 100644 --- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc +++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc @@ -63,7 +63,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync( } CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0); RdmaChannel* rc = rdma_mgr_->FindChannel(src_name); - string key(std::move(parsed.FullKey().ToString())); + string key(parsed.FullKey()); string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_); Device* dst_dev; -- GitLab From 88646f6350ed8d84462730ac9c6521a97293c7ee Mon Sep 17 00:00:00 2001 From: coder3101 Date: Sun, 2 Sep 2018 11:16:40 +0530 Subject: [PATCH 038/540] updated changes requested. Converted %(input_shape,) to % str(input_shape) --- tensorflow/python/ops/rnn_cell_impl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 8a2da5f193..973ce6306d 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -421,7 +421,7 @@ class BasicRNNCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % (inputs_shape,)) + % str(input_shape)) input_depth = inputs_shape[-1] self._kernel = self.add_variable( @@ -510,7 +510,7 @@ class GRUCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % (inputs_shape,)) + % str(input_shape)) input_depth = inputs_shape[-1] self._gate_kernel = self.add_variable( @@ -681,7 +681,7 @@ class BasicLSTMCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % (inputs_shape,)) + % str(input_shape)) input_depth = inputs_shape[-1] h_depth = self._num_units @@ -875,7 +875,7 @@ class LSTMCell(LayerRNNCell): def build(self, inputs_shape): if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" - % (inputs_shape,)) + % str(input_shape)) input_depth = inputs_shape[-1] h_depth = self._num_units if self._num_proj is None else self._num_proj -- GitLab From f7d27bc67e5d89e5f4bb6d6a0a198c28fa8af46f Mon Sep 17 00:00:00 2001 From: Sangjung Woo Date: Thu, 30 Aug 2018 17:17:23 +0900 Subject: [PATCH 039/540] fix the comparison error when building a CPP API application When building a CPP API application with "-Wall -Werror" option , `error: comparison between signed and unsigned integer expressions' occurs since return type of num_elements() is 'int64' instead of 'size_t' in ops.h to express -1. This patch fixes this bug by explicit type casting. * related issue: https://github.com/tensorflow/tensorflow/issues/20428 Signed-off-by: Sangjung Woo --- tensorflow/cc/framework/ops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index a085e1d6e2..0717e7dd4b 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -150,7 +150,7 @@ class Input { Initializer(const std::initializer_list& v, const TensorShape& shape) { typedef typename RealType::type RealT; Tensor t(DataTypeToEnum::v(), shape); - if (t.NumElements() != v.size()) { + if (t.NumElements() != static_cast(v.size())) { status = errors::InvalidArgument( "Cannot construct a tensor with ", t.NumElements(), " from an initializer list with ", v.size(), " elements"); -- GitLab From 74af314e4573e168d38072f646495034412ff061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9C=A8=E5=8E=9F=E4=BD=90=E4=B8=BA?= Date: Mon, 3 Sep 2018 10:09:05 +0800 Subject: [PATCH 040/540] use single quotation marks for single-line strings --- tensorflow/contrib/autograph/operators/slices_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index 5300428462..329d9f1f43 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -47,13 +47,13 @@ class SlicesTest(test.TestCase): self.assertAllEqual(sess.run(t), [3, 4]) def test_get_item_tensor_string(self): - initial_str = constant_op.constant("abcd") + initial_str = constant_op.constant('abcd') t = slices.get_item(initial_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) with self.test_session() as sess: - self.assertEqual(sess.run(t), b"b") + self.assertEqual(sess.run(t), b'b') - initial_list_str = constant_op.constant(["abcd", "bcde"]) + initial_list_str = constant_op.constant(['abcd', 'bcde']) t = slices.get_item(initial_list_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) with self.test_session() as sess: -- GitLab From 752e94a7d73a5c11a1b51b08bc170b0d91724a1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9C=A8=E5=8E=9F=E4=BD=90=E4=B8=BA?= Date: Mon, 3 Sep 2018 10:09:44 +0800 Subject: [PATCH 041/540] use single quotation marks for single-line strings --- tensorflow/contrib/autograph/operators/slices_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py index 329d9f1f43..2c5ffed4f2 100644 --- a/tensorflow/contrib/autograph/operators/slices_test.py +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -57,7 +57,7 @@ class SlicesTest(test.TestCase): t = slices.get_item(initial_list_str, 1, slices.GetItemOpts(element_dtype=initial_str.dtype)) with self.test_session() as sess: - self.assertEqual(sess.run(t), b"bcde") + self.assertEqual(sess.run(t), b'bcde') if __name__ == '__main__': -- GitLab From f8a3472f711729beadd671884e206452c09f0784 Mon Sep 17 00:00:00 2001 From: pengwa Date: Mon, 3 Sep 2018 19:02:50 +0800 Subject: [PATCH 042/540] fix a minor issue for tf.split document --- tensorflow/python/ops/array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 21ccbc6c33..48f7d3be40 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1275,7 +1275,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__ def split(value, num_or_size_splits, axis=0, num=None, name="split"): """Splits a tensor into sub tensors. - If `num_or_size_splits` is an integer type, `num_split`, then splits `value` + If `num_or_size_splits` is an integer type, then splits `value` along dimension `axis` into `num_split` smaller tensors. Requires that `num_split` evenly divides `value.shape[axis]`. -- GitLab From d118516dd6c5b9fd2f0bfa2b870e7cfb5063e7dc Mon Sep 17 00:00:00 2001 From: Roger Xin Date: Mon, 3 Sep 2018 11:52:42 -0400 Subject: [PATCH 043/540] Fix issues in maxout layer --- tensorflow/contrib/layers/python/layers/layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 04668f112d..a82d4c1951 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -3109,7 +3109,7 @@ def maxout(inputs, num_units, axis=-1, scope=None): inputs: Tensor input num_units: Specifies how many features will remain after maxout in the `axis` dimension (usually channel). - This must be multiple of number of `axis`. + This must be a factor of number of features. axis: The dimension where max pooling will be performed. Default is the last dimension. scope: Optional scope for variable_scope. @@ -3128,7 +3128,7 @@ def maxout(inputs, num_units, axis=-1, scope=None): raise ValueError('number of features({}) is not ' 'a multiple of num_units({})'.format( num_channels, num_units)) - shape[axis] = -1 + shape[axis] = num_units shape += [num_channels // num_units] # Dealing with batches with arbitrary sizes -- GitLab From ce9e5b035b32ef02cd7d10f6ffdd27cc2a75664d Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 1 Sep 2018 01:40:41 +0000 Subject: [PATCH 044/540] Fix syntax error in single_image_random_dot_stereograms caused by locale This fix tries to address the issue raised in 21164 where the single_image_random_dot_stereograms in different locale (like de_DE) caused syntax error in python like: ``` File "", line 28 def single_image_random_dot_stereograms(depth_values, hidden_surface_removal=True, convergence_dots_size=8, dots_per_inch=72, eye_separation=2,5, mu=0,333299994, normalize=True, normalize_max=-100, normalize_min=100, border_level=0, number_colors=256, output_image_shape=[1024, 768, 1], output_data_window=[1022, 757], name=None): ^ SyntaxError: invalid syntax ``` The issue was that the float to string conversion in python_op_gen_internal.cc triggered snprintf (in `FloatToBuffer`) which is local dependent and generates something like `eye_separatiion=2,5` in DE locale. This fix replaced the float to string conversion with locale-independent ``` std::ostringstream s; s.imbue(std::locale::classic()); ``` This fix fixes 21164. Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index f2270342b0..8ddd1e6432 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -435,7 +435,10 @@ string AttrValueToPython(const string& type, const AttrValue& value, if (std::isnan(value.f()) || std::isinf(value.f())) { return strings::StrCat("float('", value.f(), "')"); } else { - return strings::StrCat(value.f()); + std::ostringstream s; + s.imbue(std::locale::classic()); + s << value.f(); + return s.str(); } } else if (type == "bool") { return value.b() ? "True" : "False"; -- GitLab From a8a0ec4a2eaf37c853afe410964978715c3d02bb Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 1 Sep 2018 01:55:43 +0000 Subject: [PATCH 045/540] Add precision to match the existing behavior. Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index 8ddd1e6432..dafaf2fd3a 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/python/framework/python_op_gen_internal.h" #include +#include +#include #include #include #include "tensorflow/core/framework/api_def.pb.h" @@ -435,9 +437,11 @@ string AttrValueToPython(const string& type, const AttrValue& value, if (std::isnan(value.f()) || std::isinf(value.f())) { return strings::StrCat("float('", value.f(), "')"); } else { + // Use locale-independent conversion. + static_assert(FLT_DIG < 10, "FLT_DIG is too big"); std::ostringstream s; s.imbue(std::locale::classic()); - s << value.f(); + s << std::setprecision(FLT_DIG) << value.f(); return s.str(); } } else if (type == "bool") { -- GitLab From 569426a13fbae66c0acd7ed728a62f413407b898 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 1 Sep 2018 01:58:35 +0000 Subject: [PATCH 046/540] Sanitize with clang-foramt Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index dafaf2fd3a..7c4941a586 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/python/framework/python_op_gen_internal.h" -#include #include +#include #include #include #include -- GitLab From bf64fc285e88d36bb82f80757c4a1afd722347e0 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Aug 2018 15:35:12 +0000 Subject: [PATCH 047/540] Add float16 support for NonMaxSuppressionV{2,3,4} This fix tries to address the issue raised in 20199 where there was no float16 support for NonMaxSuppressionV2. As NonMaxSuppressionV2 is the earlier versions of API and there are newer versions of NonMaxSuppression: NonMaxSuppressionV2, NonMaxSuppressionV3, NonMaxSuppressionV4, This fix exposes the float16 support to all of the above. (Note in the master the default version used is NonMaxSuppressionV3) This fix fixes 20199. Signed-off-by: Yong Tang --- .../core/kernels/non_max_suppression_op.cc | 107 ++++++++++-------- tensorflow/core/ops/image_ops.cc | 15 ++- 2 files changed, 67 insertions(+), 55 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 5d9257e20b..c0ea277ed5 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -75,28 +75,29 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context, } // Return intersection-over-union overlap between boxes i and j -static inline float IOUGreaterThanThreshold( - typename TTypes::ConstTensor boxes, int i, int j, - float iou_threshold) { - const float ymin_i = std::min(boxes(i, 0), boxes(i, 2)); - const float xmin_i = std::min(boxes(i, 1), boxes(i, 3)); - const float ymax_i = std::max(boxes(i, 0), boxes(i, 2)); - const float xmax_i = std::max(boxes(i, 1), boxes(i, 3)); - const float ymin_j = std::min(boxes(j, 0), boxes(j, 2)); - const float xmin_j = std::min(boxes(j, 1), boxes(j, 3)); - const float ymax_j = std::max(boxes(j, 0), boxes(j, 2)); - const float xmax_j = std::max(boxes(j, 1), boxes(j, 3)); - const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); - const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j); - if (area_i <= 0 || area_j <= 0) return 0.0; - const float intersection_ymin = std::max(ymin_i, ymin_j); - const float intersection_xmin = std::max(xmin_i, xmin_j); - const float intersection_ymax = std::min(ymax_i, ymax_j); - const float intersection_xmax = std::min(xmax_i, xmax_j); - const float intersection_area = - std::max(intersection_ymax - intersection_ymin, 0.0) * - std::max(intersection_xmax - intersection_xmin, 0.0); - const float iou = intersection_area / (area_i + area_j - intersection_area); +template +static inline bool IOUGreaterThanThreshold( + typename TTypes::ConstTensor boxes, int i, int j, + T iou_threshold) { + const T ymin_i = std::min(boxes(i, 0), boxes(i, 2)); + const T xmin_i = std::min(boxes(i, 1), boxes(i, 3)); + const T ymax_i = std::max(boxes(i, 0), boxes(i, 2)); + const T xmax_i = std::max(boxes(i, 1), boxes(i, 3)); + const T ymin_j = std::min(boxes(j, 0), boxes(j, 2)); + const T xmin_j = std::min(boxes(j, 1), boxes(j, 3)); + const T ymax_j = std::max(boxes(j, 0), boxes(j, 2)); + const T xmax_j = std::max(boxes(j, 1), boxes(j, 3)); + const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); + const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j); + if (area_i <= static_cast(0) || area_j <= static_cast(0)) return 0; + const T intersection_ymin = std::max(ymin_i, ymin_j); + const T intersection_xmin = std::max(xmin_i, xmin_j); + const T intersection_ymax = std::min(ymax_i, ymax_j); + const T intersection_xmax = std::min(xmax_i, xmax_j); + const T intersection_area = + std::max(intersection_ymax - intersection_ymin, static_cast(0.0)) * + std::max(intersection_xmax - intersection_xmin, static_cast(0.0)); + const T iou = intersection_area / (area_i + area_j - intersection_area); return iou > iou_threshold; } @@ -106,11 +107,12 @@ static inline bool OverlapsGreaterThanThreshold( return overlaps(i, j) > overlap_threshold; } +template static inline std::function CreateIOUSuppressCheckFn( const Tensor& boxes, float threshold) { - typename TTypes::ConstTensor boxes_data = boxes.tensor(); - return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1, - std::placeholders::_2, threshold); + typename TTypes::ConstTensor boxes_data = boxes.tensor(); + return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1, + std::placeholders::_2, static_cast(threshold)); } static inline std::function CreateOverlapsSuppressCheckFn( @@ -121,6 +123,7 @@ static inline std::function CreateOverlapsSuppressCheckFn( std::placeholders::_1, std::placeholders::_2, threshold); } +template void DoNonMaxSuppressionOp( OpKernelContext* context, const Tensor& scores, int num_boxes, const Tensor& max_output_size, const float score_threshold, @@ -128,13 +131,13 @@ void DoNonMaxSuppressionOp( bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) { const int output_size = max_output_size.scalar()(); - std::vector scores_data(num_boxes); - std::copy_n(scores.flat().data(), num_boxes, scores_data.begin()); + std::vector scores_data(num_boxes); + std::copy_n(scores.flat().data(), num_boxes, scores_data.begin()); // Data structure for selection candidate in NMS. struct Candidate { int box_index; - float score; + T score; }; auto cmp = [](const Candidate bs_i, const Candidate bs_j) { @@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp( std::priority_queue, decltype(cmp)> candidate_priority_queue(cmp); for (int i = 0; i < scores_data.size(); ++i) { - if (scores_data[i] > score_threshold) { + if (scores_data[i] > static_cast(score_threshold)) { candidate_priority_queue.emplace(Candidate({i, scores_data[i]})); } } std::vector selected; - std::vector selected_scores; + std::vector selected_scores; Candidate next_candidate; while (selected.size() < output_size && !candidate_priority_queue.empty()) { @@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp( int num_valid_outputs = selected.size(); if (pad_to_max_output_size) { selected.resize(output_size, 0); - selected_scores.resize(output_size, 0); + selected_scores.resize(output_size, static_cast(0)); } if (ptr_num_valid_outputs) { *ptr_num_valid_outputs = num_valid_outputs; @@ -221,10 +224,10 @@ class NonMaxSuppressionOp : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_); + auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_); const float score_threshold_val = std::numeric_limits::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, + DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, score_threshold_val, suppress_check_fn); } @@ -232,7 +235,7 @@ class NonMaxSuppressionOp : public OpKernel { float iou_threshold_; }; -template +template class NonMaxSuppressionV2Op : public OpKernel { public: explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) @@ -264,10 +267,10 @@ class NonMaxSuppressionV2Op : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); + auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); const float score_threshold_val = std::numeric_limits::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, + DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, score_threshold_val, suppress_check_fn); } }; @@ -325,7 +328,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel { float score_threshold_val_; }; -template +template class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) @@ -334,14 +337,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { protected: void DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, score_threshold_val_, suppress_check_fn); } }; -template +template class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV4Op(OpKernelConstruction* context) @@ -353,10 +356,10 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { protected: void DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); int num_valid_outputs; - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, + DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, score_threshold_val_, suppress_check_fn, pad_to_max_output_size_, &num_valid_outputs); @@ -413,7 +416,7 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel { auto suppress_check_fn = CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, + DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, score_threshold_val, suppress_check_fn); } }; @@ -421,14 +424,20 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), NonMaxSuppressionOp); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), - NonMaxSuppressionV2Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV2Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV2Op); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU), - NonMaxSuppressionV3Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV3Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV3Op); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU), - NonMaxSuppressionV4Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV4Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV4Op); REGISTER_KERNEL_BUILDER( Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU), diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 11ca0bd259..abb4e6fcf6 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -683,11 +683,12 @@ REGISTER_OP("NonMaxSuppression") }); REGISTER_OP("NonMaxSuppressionV2") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Output("selected_indices: int32") + .Attr("T: {half, float}") .SetShapeFn([](InferenceContext* c) { // Get inputs and validate ranks. ShapeHandle boxes; @@ -711,22 +712,24 @@ REGISTER_OP("NonMaxSuppressionV2") }); REGISTER_OP("NonMaxSuppressionV3") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") + .Attr("T: {half, float}") .SetShapeFn(NMSShapeFn); REGISTER_OP("NonMaxSuppressionV4") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") .Output("valid_outputs: int32") + .Attr("T: {half, float}") .Attr("pad_to_max_output_size: bool = false") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(NMSShapeFn(c)); -- GitLab From 141d5d666694a37cda65c440315c135d9a6a48a7 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Aug 2018 16:45:21 +0000 Subject: [PATCH 048/540] Add test cases for float16 support for non_max_suppression Signed-off-by: Yong Tang --- tensorflow/python/ops/image_ops_test.py | 41 +++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index f7502c4018..ee76d3d1dc 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -3657,6 +3657,47 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): scores = constant_op.constant([0.9]) image_ops.non_max_suppression(boxes, scores, 3, [[0.5]]) + def testDataTypes(self): + # Test case for GitHub issue 20199. + boxes_np = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], + [0, 10, 1, 11], [0, 10.1, 1, 11.1], [0, 100, 1, 101]] + scores_np = [0.9, 0.75, 0.6, 0.95, 0.5, 0.3] + max_output_size_np = 3 + iou_threshold_np = 0.5 + # Note: There are multiple versions of non_max_suppression v2, v3, v4. + # gen_image_ops.non_max_suppression_v2: + for dtype in [np.float16, np.float32]: + with self.test_session(): + boxes = constant_op.constant(boxes_np, dtype=dtype) + scores = constant_op.constant(scores_np, dtype=dtype) + max_output_size = constant_op.constant(max_output_size_np) + iou_threshold = constant_op.constant(iou_threshold_np) + selected_indices = gen_image_ops.non_max_suppression_v2( + boxes, scores, max_output_size, iou_threshold).eval() + self.assertAllClose(selected_indices, [3, 0, 5]) + # image_ops.non_max_suppression = gen_image_ops.non_max_suppression_v3. + for dtype in [np.float16, np.float32]: + with self.test_session(): + boxes = constant_op.constant(boxes_np, dtype=dtype) + scores = constant_op.constant(scores_np, dtype=dtype) + max_output_size = constant_op.constant(max_output_size_np) + iou_threshold = constant_op.constant(iou_threshold_np) + selected_indices = image_ops.non_max_suppression( + boxes, scores, max_output_size, iou_threshold).eval() + self.assertAllClose(selected_indices, [3, 0, 5]) + # gen_image_ops.non_max_suppression_v4. + score_threshold=float('-inf') + for dtype in [np.float16, np.float32]: + with self.test_session(): + boxes = constant_op.constant(boxes_np, dtype=dtype) + scores = constant_op.constant(scores_np, dtype=dtype) + max_output_size = constant_op.constant(max_output_size_np) + iou_threshold = constant_op.constant(iou_threshold_np) + selected_indices, _ = gen_image_ops.non_max_suppression_v4( + boxes, scores, max_output_size, iou_threshold, score_threshold) + selected_indices = selected_indices.eval() + self.assertAllClose(selected_indices, [3, 0, 5]) + class NonMaxSuppressionPaddedTest(test_util.TensorFlowTestCase): -- GitLab From ad143cec3f8d9cac0953b9f4bce9a56f659d73d8 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 14 Aug 2018 17:03:02 +0000 Subject: [PATCH 049/540] Pylint fix Signed-off-by: Yong Tang --- tensorflow/python/ops/image_ops_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index ee76d3d1dc..795e6bbc3e 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -3686,7 +3686,7 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): boxes, scores, max_output_size, iou_threshold).eval() self.assertAllClose(selected_indices, [3, 0, 5]) # gen_image_ops.non_max_suppression_v4. - score_threshold=float('-inf') + score_threshold = float('-inf') for dtype in [np.float16, np.float32]: with self.test_session(): boxes = constant_op.constant(boxes_np, dtype=dtype) -- GitLab From ad997f1c24829dbe3c687d449a757202c401bb6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9C=A8=E5=8E=9F=E4=BD=90=E4=B8=BA?= Date: Tue, 4 Sep 2018 23:25:30 +0800 Subject: [PATCH 050/540] only apply _string_get_item for string with rank 0 --- tensorflow/contrib/autograph/operators/slices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py index d878bddf3c..a885bdab5b 100644 --- a/tensorflow/contrib/autograph/operators/slices.py +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -58,7 +58,7 @@ def get_item(target, i, opts): elif tensor_util.is_tensor(target): if target.dtype == dtypes.variant: return _tf_tensor_list_get_item(target, i, opts) - if target.dtype == dtypes.string: + elif target.dtype == dtypes.string and target.get_shape() == (): # target is string with rank 0 return _tf_tensor_string_get_item(target, i) else: return _tf_tensor_get_item(target, i) -- GitLab From 0e9af928f7a6711971ade159a511da093f307a81 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 11:07:59 -0700 Subject: [PATCH 051/540] Removed redundant std::string -> string conversions. PiperOrigin-RevId: 211487989 --- .../core/common_runtime/bfc_allocator.cc | 2 +- .../core/common_runtime/graph_runner.cc | 4 +-- .../core/common_runtime/session_state.cc | 2 +- .../common_runtime/step_stats_collector.cc | 6 ++--- tensorflow/core/kernels/gpu_utils.h | 3 +-- .../kernels/merge_v2_checkpoints_op_test.cc | 4 +-- .../remote_fused_graph_execute_utils.cc | 26 +++++++++---------- .../core/kernels/save_restore_v2_ops.cc | 4 +-- tensorflow/core/kernels/string_strip_op.cc | 2 +- tensorflow/core/kernels/tensor_array_ops.cc | 2 +- .../core/kernels/whole_file_read_ops.cc | 2 +- .../core/platform/cloud/curl_http_request.cc | 4 +-- .../core/platform/cloud/gcs_file_system.cc | 14 +++++----- .../core/platform/cloud/oauth_client.cc | 4 +-- .../core/platform/cloud/oauth_client_test.cc | 6 ++--- .../freeze_requantization_ranges.cc | 2 +- .../graph_transforms/sparsify_gather_test.cc | 4 +-- .../tools/graph_transforms/transform_graph.cc | 15 +++++------ .../tools/graph_transforms/transform_utils.cc | 2 +- 19 files changed, 51 insertions(+), 57 deletions(-) diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 3bf0532491..84c6285bbe 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -596,7 +596,7 @@ string BFCAllocator::RenderOccupancy() { region_offset += region.memory_size(); } - return std::string(rendered, resolution); + return string(rendered, resolution); } void BFCAllocator::DumpMemoryLog(size_t num_bytes) { diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 0a1797fa19..f9aef3af70 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous { } mutex_lock l(mu_); - string edge_name = std::string(parsed.edge_name); + string edge_name(parsed.edge_name); if (table_.count(edge_name) > 0) { return errors::Internal("Send of an already sent tensor"); } @@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous { Tensor tensor; Status status = Status::OK(); { - string key = std::string(parsed.edge_name); + string key(parsed.edge_name); mutex_lock l(mu_); if (table_.count(key) <= 0) { status = errors::Internal("Did not find key ", key); diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index 65ff356e73..5b1915755d 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector& output_names, // Save only the tensors in output_names in the session. for (const string& name : output_names) { TensorId id(ParseTensorName(name)); - const string& op_name = std::string(id.first); + const string op_name(id.first); auto it = tensors_.find(op_name); if (it != tensors_.end()) { // Save the tensor to the session state. diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 9c2510e6a9..836cb8ed14 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -176,7 +176,7 @@ static int ExtractGpuWithStreamAll(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -205,7 +205,7 @@ static int ExtractGpuWithoutStream(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -252,7 +252,7 @@ void StepStatsCollector::BuildCostModel( for (auto& itr : per_device_stats) { const StringPiece device_name = itr.first; - const int gpu_id = ExtractGpuWithoutStream(std::string(device_name)); + const int gpu_id = ExtractGpuWithoutStream(string(device_name)); if (gpu_id >= 0) { // Reference the gpu hardware stats in addition to the regular stats // for this gpu device if they're available. diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index c7dbefa0b4..86146f75f4 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -123,8 +123,7 @@ class AutoTuneMap { string GetActionSummary(StringPiece action, const Parameters& params, const Config& config) { return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(), - std::string(action).c_str(), - params.ToString().c_str(), + string(action).c_str(), params.ToString().c_str(), config.ToString().c_str()); } diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc index 10e468ce46..693ed8a8f0 100644 --- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc +++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc @@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase { // Exercises "delete_old_dirs". for (int i = 0; i < 2; ++i) { int directory_found = - Env::Default() - ->IsDirectory(std::string(io::Dirname(prefixes[i]))) - .code(); + Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code(); if (delete_old_dirs) { EXPECT_EQ(error::NOT_FOUND, directory_found); } else { diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 194a711d98..26f107f940 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -47,7 +47,7 @@ std::unordered_set BuildNodeSetFromNodeNamesAndPorts( std::unordered_set retval; for (const string& node_name_and_port : node_names_and_ports) { const TensorId tid = ParseTensorName(node_name_and_port); - retval.emplace(std::string(tid.first)); + retval.emplace(tid.first); } return retval; } @@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) { const NodeDef* FindNodeDefByName(const string& input, const GraphDef& graph_def) { const TensorId tid = ParseTensorName(input); - const string name = std::string(tid.first); + const string name = string(tid.first); for (const NodeDef& node_def : graph_def.node()) { if (node_def.name() == name) { return &node_def; @@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( std::vector data_types; std::vector shapes; const TensorId tid = ParseTensorName(name_and_port); - const string node_name = std::string(tid.first); + const string node_name(tid.first); const int port = tid.second; const NodeDef* node_def = FindNodeDefByName(node_name, graph_def); CHECK_NOTNULL(node_def); @@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType( const TensorShapeMap& tensor_shape_map, const string& node_name) { if (node_name.find(':') != string::npos) { const TensorId tid = ParseTensorName(node_name); - return GetTensorShapeType(tensor_shape_map, std::string(tid.first), - tid.second); + return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second); } else { return GetTensorShapeType(tensor_shape_map, node_name, 0); } @@ -570,7 +569,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( const TensorId tid = ParseTensorName(name); CHECK_EQ(tensor_shape_map->count(name), 0); tensor_shape_map->emplace( - std::string(tid.first), + string(tid.first), std::make_pair(tid.second, std::make_pair(tensor.dtype(), tensor.shape()))); } @@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::vector node_out_list; for (const string& input : inputs) { const TensorId tid = ParseTensorName(input); - Node* node = FindMutableNodeByName(std::string(tid.first), graph); + Node* node = FindMutableNodeByName(string(tid.first), graph); CHECK_NOTNULL(node); node_out_list.emplace_back(node, tid.second); } @@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (const string& subgraph_input : std::get<1>(cluster)) { const TensorId tid = ParseTensorName(subgraph_input); - const string subgraph_input_name = std::string(tid.first); + const string subgraph_input_name(tid.first); const int subgraph_input_port = tid.second; const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def); CHECK_NOTNULL(node_def); @@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::deque queue; for (const string& output : border_outputs) { const TensorId tid = ParseTensorName(output); - const string& output_node_name = std::string(tid.first); + const string output_node_name(tid.first); for (const Node* node : graph.nodes()) { if (output_node_name == node->name()) { queue.push_back(node); @@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (int j = 0; j < border_outputs.size(); ++j) { const string& output = border_outputs.at(j); const TensorId tid = ParseTensorName(output); - const string output_name = std::string(tid.first); + const string output_name(tid.first); Node* src_node = edge->src(); if (src_node != nullptr && src_node->name() == output_name && edge->src_output() == tid.second) { @@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( // RemoteFusedGraphExecuteOpNode for (const string& output : outputs) { const TensorId output_tid = ParseTensorName(output); - const string output_name = std::string(output_tid.first); + const string output_name(output_tid.first); for (size_t i = 0; i < border_outputs.size(); ++i) { const TensorId subgraph_output_tid = ParseTensorName(border_outputs.at(i)); - const string& subgraph_output_name = - std::string(subgraph_output_tid.first); + const string subgraph_output_name(subgraph_output_tid.first); if (output_name == subgraph_output_name) { LOG(INFO) << "As graph output and subgraph output are same, " << "the graph output node is replaced by identity node"; @@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( GraphDef* graph_def) { const TensorId tid = ParseTensorName(input); CHECK_EQ(0, tid.second); - const string node_name = std::string(tid.first); + const string node_name(tid.first); for (NodeDef& node : *graph_def->mutable_node()) { if (node.name() != node_name) { continue; diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index ab4de6c815..180eb3ca34 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel { context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix)); if (delete_old_dirs_) { - const string& merged_dir = std::string(io::Dirname(merged_prefix)); + const string merged_dir(io::Dirname(merged_prefix)); for (const string& input_prefix : input_prefixes) { - const string& dirname = std::string(io::Dirname(input_prefix)); + const string dirname(io::Dirname(input_prefix)); if (dirname == merged_dir) continue; Status status = env->DeleteDir(dirname); // For sharded save, only the first delete will go through and all diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc index 2aeafa28c4..544dca96ba 100644 --- a/tensorflow/core/kernels/string_strip_op.cc +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -43,7 +43,7 @@ class StringStripOp : public OpKernel { for (int64 i = 0; i < input.size(); ++i) { StringPiece entry(input(i)); str_util::RemoveWhitespaceContext(&entry); - output(i) = std::string(entry); + output(i) = string(entry); } } }; diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 632b65e9b6..2ec2651c04 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -297,7 +297,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { resource.name()); } tensor_array_name = - std::string(StringPiece(resource.name()).substr(container.size())); + string(StringPiece(resource.name()).substr(container.size())); } auto output_handle = tensor_array_output_handle->flat(); diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index ed2bf3e8e2..1bf46b5e46 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel { "Contents tensor must be scalar, but had shape: ", contents_input->shape().DebugString())); const string& filename = filename_input->scalar()(); - const string dir = std::string(io::Dirname(filename)); + const string dir(io::Dirname(filename)); if (!context->env()->FileExists(dir).ok()) { OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir)); } diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index a1be4aacce..5e1eabee5b 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size, .StopCapture() .OneLiteral(": ") .GetResult(&value, &name)) { - string str_value = std::string(value); + string str_value(value); str_util::StripTrailingWhitespace(&str_value); - that->response_headers_[std::string(name)] = str_value; + that->response_headers_[string(name)] = str_value; } return size * nmemb; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 9d33787bd5..8f959c018e 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -179,13 +179,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket, return errors::InvalidArgument("GCS path doesn't start with 'gs://': ", fname); } - *bucket = std::string(bucketp); + *bucket = string(bucketp); if (bucket->empty() || *bucket == ".") { return errors::InvalidArgument("GCS path doesn't contain a bucket name: ", fname); } str_util::ConsumePrefix(&objectp, "/"); - *object = std::string(objectp); + *object = string(objectp); if (!empty_object_ok && object->empty()) { return errors::InvalidArgument("GCS path doesn't contain an object name: ", fname); @@ -224,7 +224,7 @@ std::set AddAllSubpaths(const std::vector& paths) { for (const string& path : paths) { StringPiece subpath = io::Dirname(path); while (!subpath.empty()) { - result.emplace(std::string(subpath)); + result.emplace(string(subpath)); subpath = io::Dirname(subpath); } } @@ -723,7 +723,7 @@ GcsFileSystem::GcsFileSystem() { if (!header_name.empty() && !header_value.empty()) { additional_header_.reset(new std::pair( - std::string(header_name), std::string(header_value))); + string(header_name), string(header_value))); VLOG(1) << "GCS additional header ENABLED. " << "Name: " << additional_header_->first << ", " @@ -1229,7 +1229,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // Find the fixed prefix by looking for the first wildcard. const string& fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); - const string& dir = std::string(io::Dirname(fixed_prefix)); + const string dir(io::Dirname(fixed_prefix)); if (dir.empty()) { return errors::InvalidArgument( "A GCS pattern doesn't have a bucket name: ", pattern); @@ -1326,7 +1326,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, " doesn't match the prefix ", object_prefix)); } if (!relative_path.empty() || include_self_directory_marker) { - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); } if (++retrieved_results >= max_results) { return Status::OK(); @@ -1354,7 +1354,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, "Unexpected response: the returned folder name ", prefix_str, " doesn't match the prefix ", object_prefix); } - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); if (++retrieved_results >= max_results) { return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index ee6ba7b041..9b85cae9b9 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson( // Send the request to the Google OAuth 2.0 server to get the token. std::unique_ptr request(http_request_factory_->Create()); std::vector response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); @@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson( std::unique_ptr request(http_request_factory_->Create()); std::vector response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 4ffa72288b..1cd0641cd3 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", grant_type); - int last_dot = std::string(assertion).find_last_of("."); - string header_dot_claim = std::string(assertion.substr(0, last_dot)); - string signature_encoded = std::string(assertion.substr(last_dot + 1)); + int last_dot = assertion.rfind('.'); + string header_dot_claim(assertion.substr(0, last_dot)); + string signature_encoded(assertion.substr(last_dot + 1)); // Check that 'signature' signs 'header_dot_claim'. diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc index c8dc2a7c4d..d97496cbeb 100644 --- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc +++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges.cc @@ -92,7 +92,7 @@ Status ExtractMinMaxRecords(const string& log_file_name, if (!str_util::EndsWith(name_string, print_suffix)) { continue; } - string name = std::string( + string name( name_string.substr(0, name_string.size() - print_suffix.size())); records->push_back({name, min, max}); } diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index dd95779a1f..b8d6ba00de 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -42,8 +42,8 @@ class SparsifyGatherTest : public ::testing::Test { const std::vector& inputs, GraphDef* graph_def, bool control_dep = false) { NodeDef* node_def = graph_def->add_node(); - node_def->set_name(std::string(name)); - node_def->set_op(std::string(op)); + node_def->set_name(string(name)); + node_def->set_op(string(op)); if (!control_dep) { std::for_each(inputs.begin(), inputs.end(), [&node_def](NodeDef* input) { node_def->add_input(input->name()); diff --git a/tensorflow/tools/graph_transforms/transform_graph.cc b/tensorflow/tools/graph_transforms/transform_graph.cc index 5cae8f8d8f..7efe450710 100644 --- a/tensorflow/tools/graph_transforms/transform_graph.cc +++ b/tensorflow/tools/graph_transforms/transform_graph.cc @@ -65,19 +65,19 @@ Status ParseTransformParameters(const string& transforms_string, .GetResult(&remaining, &transform_name); if (!found_transform_name) { return errors::InvalidArgument("Looking for transform name, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) { state = TRANSFORM_PARAM_NAME; } else { // Add a transform with no parameters. - params_list->push_back({std::string(transform_name), func_parameters}); + params_list->push_back({string(transform_name), func_parameters}); transform_name = ""; state = TRANSFORM_NAME; } } else if (state == TRANSFORM_PARAM_NAME) { if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) { - params_list->push_back({std::string(transform_name), func_parameters}); + params_list->push_back({string(transform_name), func_parameters}); transform_name = ""; state = TRANSFORM_NAME; } else { @@ -92,13 +92,13 @@ Status ParseTransformParameters(const string& transforms_string, if (!found_parameter_name) { return errors::InvalidArgument( "Looking for parameter name, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) { state = TRANSFORM_PARAM_VALUE; } else { return errors::InvalidArgument("Looking for =, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } } } else if (state == TRANSFORM_PARAM_VALUE) { @@ -120,10 +120,9 @@ Status ParseTransformParameters(const string& transforms_string, } if (!found_parameter_value) { return errors::InvalidArgument("Looking for parameter name, but found ", - std::string(remaining).c_str()); + string(remaining).c_str()); } - func_parameters[std::string(parameter_name)].push_back( - std::string(parameter_value)); + func_parameters[string(parameter_name)].emplace_back(parameter_value); // Eat up any trailing quotes. Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match); Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match); diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index cb084e49b7..c715380aae 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -93,7 +93,7 @@ void NodeNamePartsFromInput(const string& input_name, string* prefix, } else { *prefix = ""; } - *node_name = std::string(node_name_piece); + *node_name = string(node_name_piece); } string NodeNameFromInput(const string& input_name) { -- GitLab From 4cd79b3f6361b6518463349a51fe33f7520f3b49 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 11:10:27 -0700 Subject: [PATCH 052/540] Fix LazyAdamOptimizer for sparse updates on resource variables. PiperOrigin-RevId: 211488610 --- .../python/training/lazy_adam_optimizer.py | 63 ++++++++++++++----- .../training/lazy_adam_optimizer_test.py | 17 ++++- 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py index 72117c1e81..f026f437dc 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py @@ -25,9 +25,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import adam @@ -46,7 +48,12 @@ class LazyAdamOptimizer(adam.AdamOptimizer): may lead to different empirical results. """ - def _apply_sparse(self, grad, var): + def _apply_sparse_shared(self, + grad, + var, + indices, + scatter_update, + scatter_sub): beta1_power, beta2_power = self._get_beta_accumulators() beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) @@ -58,23 +65,51 @@ class LazyAdamOptimizer(adam.AdamOptimizer): # \\(m := beta1 * m + (1 - beta1) * g_t\\) m = self.get_slot(var, "m") - m_t = state_ops.scatter_update(m, grad.indices, - beta1_t * array_ops.gather(m, grad.indices) + - (1 - beta1_t) * grad.values, - use_locking=self._use_locking) + m_t = scatter_update(m, indices, + beta1_t * array_ops.gather(m, indices) + + (1 - beta1_t) * grad) # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) v = self.get_slot(var, "v") - v_t = state_ops.scatter_update(v, grad.indices, - beta2_t * array_ops.gather(v, grad.indices) + - (1 - beta2_t) * math_ops.square(grad.values), - use_locking=self._use_locking) + v_t = scatter_update(v, indices, + beta2_t * array_ops.gather(v, indices) + + (1 - beta2_t) * math_ops.square(grad)) # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - m_t_slice = array_ops.gather(m_t, grad.indices) - v_t_slice = array_ops.gather(v_t, grad.indices) + m_t_slice = array_ops.gather(m_t, indices) + v_t_slice = array_ops.gather(v_t, indices) denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t - var_update = state_ops.scatter_sub(var, grad.indices, - lr * m_t_slice / denominator_slice, - use_locking=self._use_locking) + var_update = scatter_sub(var, indices, + lr * m_t_slice / denominator_slice) return control_flow_ops.group(var_update, m_t, v_t) + + def _apply_sparse(self, grad, var): + return self._apply_sparse_shared( + grad.values, var, grad.indices, + self._scatter_update, + self._scatter_sub) + + def _resource_apply_sparse(self, grad, var, indices): + return self._apply_sparse_shared( + grad, var, indices, + self._resource_scatter_update, + self._resource_scatter_sub) + + # Utility functions for updating resource or non-resource variables. + def _scatter_update(self, x, i, v): + return state_ops.scatter_update( + x, i, v, use_locking=self._use_locking) + + def _scatter_sub(self, x, i, v): + return state_ops.scatter_sub( + x, i, v, use_locking=self._use_locking) + + def _resource_scatter_update(self, x, i, v): + update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v) + with ops.control_dependencies([update_op]): + return x.value() + + def _resource_scatter_sub(self, x, i, v): + sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v) + with ops.control_dependencies([sub_op]): + return x.value() diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py index dc4c462ce4..d3e9e89502 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -51,7 +52,7 @@ def adam_update_numpy(param, class AdamOptimizerTest(test.TestCase): - def testSparse(self): + def doTestSparse(self, use_resource=False): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. @@ -61,8 +62,12 @@ class AdamOptimizerTest(test.TestCase): var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) + if use_resource: + var0 = resource_variable_ops.ResourceVariable(var0_np) + var1 = resource_variable_ops.ResourceVariable(var1_np) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) grads0_np_indices = np.array([0, 1], dtype=np.int32) grads0 = ops.IndexedSlices( constant_op.constant(grads0_np), @@ -94,6 +99,12 @@ class AdamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var1_np, var1.eval()) + def testSparse(self): + self.doTestSparse(use_resource=False) + + def testResourceSparse(self): + self.doTestSparse(use_resource=True) + def testSparseDevicePlacement(self): for index_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(force_gpu=test.is_gpu_available()): -- GitLab From 9ae8214229960c634c9f82c00f2c0df287c27a9d Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Tue, 4 Sep 2018 11:15:12 -0700 Subject: [PATCH 053/540] Support zeros_like for nested TensorLists. PiperOrigin-RevId: 211489741 --- tensorflow/core/kernels/list_kernels.h | 21 +++++++++- .../python/kernel_tests/list_ops_test.py | 41 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 066a1d603b..72581c9293 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -374,7 +374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, y->tensors.reserve(x.tensors.size()); for (const Tensor& t : x.tensors) { Tensor out_tensor; - TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor)); + AllocatorAttributes attr; + if (t.dtype() == DT_VARIANT) { + attr.set_on_host(true); + } + TF_RETURN_IF_ERROR( + c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr)); switch (out_tensor.dtype()) { #define DTYPE_CASE(dtype) \ case DataTypeToEnum::value: \ @@ -385,6 +390,20 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, TF_CALL_POD_TYPES(DTYPE_CASE) #undef DTYPE_CASE + + case DataTypeToEnum::value: { + const TensorList* inner_x = t.scalar()().get(); + if (inner_x == nullptr) { + return errors::InvalidArgument("Input handle is not a list. Saw: '", + t.scalar()().DebugString(), + "'"); + } + TensorList inner_y; + TF_RETURN_IF_ERROR(TensorListZerosLike(c, *inner_x, &inner_y)); + out_tensor.scalar()() = std::move(inner_y); + break; + } + default: return errors::InvalidArgument( "Trying to compute zeros_like for unsupported dtype ", diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index 9b6aee64aa..ff941b64fa 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -476,6 +476,47 @@ class ListOpsTest(test_util.TensorFlowTestCase): self.evaluate(t_full_zeros), np.zeros( (2,), dtype=dtype.as_numpy_dtype)) + @test_util.run_in_graph_and_eager_modes + def testZerosLikeVariant(self): + for dtype in (dtypes.uint8, dtypes.uint16, dtypes.int8, dtypes.int16, + dtypes.int32, dtypes.int64, dtypes.float16, dtypes.float32, + dtypes.float64, dtypes.complex64, dtypes.complex128, + dtypes.bool): + l = list_ops.empty_tensor_list( + element_dtype=dtypes.variant, element_shape=scalar_shape()) + + sub_l = list_ops.empty_tensor_list( + element_dtype=dtype, element_shape=scalar_shape()) + l = list_ops.tensor_list_push_back(l, sub_l) + sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast( + 1, dtype=dtype)) + l = list_ops.tensor_list_push_back(l, sub_l) + sub_l = list_ops.tensor_list_push_back(sub_l, math_ops.cast( + 2, dtype=dtype)) + l = list_ops.tensor_list_push_back(l, sub_l) + + # l : [[], + # [1], + # [1, 2]] + # + # l_zeros : [[], + # [0], + # [0, 0]] + l_zeros = array_ops.zeros_like(l) + + outputs = [] + for _ in range(3): + l_zeros, out = list_ops.tensor_list_pop_back( + l_zeros, element_dtype=dtypes.variant) + outputs.append(list_ops.tensor_list_stack(out, element_dtype=dtype)) + + # Note: `outputs` contains popped values so the order is reversed. + self.assertAllEqual(self.evaluate(outputs[2]), []) + self.assertAllEqual( + self.evaluate(outputs[1]), np.zeros((1,), dtype=dtype.as_numpy_dtype)) + self.assertAllEqual( + self.evaluate(outputs[0]), np.zeros((2,), dtype=dtype.as_numpy_dtype)) + if __name__ == "__main__": test.main() -- GitLab From 5d183ab7fc7b82f1dea0b9fa9c6412c39ade15a1 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Tue, 4 Sep 2018 11:17:30 -0700 Subject: [PATCH 054/540] [XLA] Make kConvolution, kDot HLO attributes mandatory HLO transformations would forget to propagate the feature depth attribute. Making these attributes mandatory, while slightly less convenient for tests, makes HLO transformations more robust. PiperOrigin-RevId: 211490160 --- tensorflow/compiler/xla/client/xla_builder.cc | 4 +- tensorflow/compiler/xla/reference_util.cc | 14 ++- .../xla/service/algebraic_simplifier.cc | 21 ++-- .../xla/service/algebraic_simplifier_test.cc | 50 +++++--- .../xla/service/batch_dot_simplification.cc | 4 +- .../service/bfloat16_normalization_test.cc | 5 +- .../xla/service/buffer_assignment_test.cc | 11 +- .../convolution_feature_group_converter.cc | 4 +- .../xla/service/cpu/conv_canonicalization.cc | 6 +- .../service/cpu/conv_canonicalization_test.cc | 13 +- .../cpu/cpu_instruction_fusion_test.cc | 6 +- .../compiler/xla/service/dot_decomposer.cc | 6 +- .../gpu/cudnn_convolution_rewriter_test.cc | 112 +++++++++++------- .../compiler/xla/service/graphviz_example.cc | 7 +- .../xla/service/heap_simulator_test.cc | 31 +++-- .../xla/service/hlo_computation_test.cc | 15 ++- .../xla/service/hlo_creation_utils.cc | 25 ++-- .../compiler/xla/service/hlo_creation_utils.h | 11 +- .../xla/service/hlo_dataflow_analysis_test.cc | 5 +- .../compiler/xla/service/hlo_evaluator.cc | 6 +- .../compiler/xla/service/hlo_evaluator.h | 3 +- .../xla/service/hlo_evaluator_test.cc | 37 ++++-- .../xla/service/hlo_evaluator_typed_visitor.h | 7 +- .../compiler/xla/service/hlo_instruction.cc | 57 ++++++--- .../compiler/xla/service/hlo_instruction.h | 7 +- .../xla/service/hlo_instruction_test.cc | 35 +++--- .../compiler/xla/service/hlo_instructions.cc | 14 ++- .../compiler/xla/service/hlo_instructions.h | 11 +- tensorflow/compiler/xla/service/hlo_parser.cc | 41 ++++--- .../compiler/xla/service/hlo_verifier.cc | 4 +- .../xla/service/indexed_array_analysis.cc | 27 +++-- .../xla/service/indexed_array_analysis.h | 10 +- .../compiler/xla/service/shape_inference.cc | 4 +- .../compiler/xla/service/shape_inference.h | 6 +- .../xla/service/shape_inference_test.cc | 16 +-- .../compiler/xla/service/transpose_folding.cc | 7 +- .../xla/service/transpose_folding_test.cc | 31 +++-- .../service/tuple_points_to_analysis_test.cc | 5 +- .../xla/tests/multioutput_fusion_test.cc | 12 +- 39 files changed, 436 insertions(+), 254 deletions(-) diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index e639028ccd..7f2125f74c 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -990,8 +990,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), ShapeInference::InferConvolveShape( - lhs_shape, rhs_shape, instr.window(), - dimension_numbers, feature_group_count)); + lhs_shape, rhs_shape, feature_group_count, + instr.window(), dimension_numbers)); *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index a4854f593f..8a05d1b0d7 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -564,18 +564,22 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( dim2.set_base_dilation(lhs_dilation.second); *window.add_dimensions() = dim2; - const Shape& shape = - ShapeInference::InferConvolveShape(lhs_literal->shape(), - rhs_literal->shape(), window, dnums) - .ConsumeValueOrDie(); + const Shape& shape = ShapeInference::InferConvolveShape( + lhs_literal->shape(), rhs_literal->shape(), + /*feature_group_count=*/1, window, dnums) + .ConsumeValueOrDie(); HloInstruction* lhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfigProto::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, precision_config)); HloModuleConfig config; HloModule module("ReferenceUtil", config); auto computation = module.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 7c078f07d7..3d18fe3be2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -950,9 +950,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( new_dot_rhs = rhs_slice; } - auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot.shape(), new_dot_lhs, new_dot_rhs, new_dot_dnums)); - new_dot->set_precision_config(dot.precision_config()); + auto* new_dot = computation_->AddInstruction( + HloInstruction::CreateDot(dot.shape(), new_dot_lhs, new_dot_rhs, + new_dot_dnums, dot.precision_config())); if (add_result) { add_result = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -1053,9 +1053,9 @@ StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfGather( const int n = right_operand->shape().dimensions(1 - rhs_contracting_dimension); auto memoized_shape = ShapeUtil::MakeShape(F32, {m, n}); - auto* memoized_inst = computation_->AddInstruction(HloInstruction::CreateDot( - memoized_shape, left_operand, right_operand, dnums)); - memoized_inst->set_precision_config(dot->precision_config()); + auto* memoized_inst = computation_->AddInstruction( + HloInstruction::CreateDot(memoized_shape, left_operand, right_operand, + dnums, dot->precision_config())); // Get pair {start, 0} or {0, start}. HloInstruction* original_start_indices = lhs_is_dynamic_slice ? lhs->mutable_operand(1) : rhs->mutable_operand(1); @@ -1151,9 +1151,8 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), - rhs->mutable_operand(0), lhs->mutable_operand(0), - dot_dimension_numbers)); - new_dot->set_precision_config(dot->precision_config()); + rhs->mutable_operand(0), lhs->mutable_operand(0), dot_dimension_numbers, + dot->precision_config())); return ReplaceWithNewInstruction( dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } @@ -2477,8 +2476,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( dot_dimension_numbers.add_lhs_contracting_dimensions(1); dot_dimension_numbers.add_rhs_contracting_dimensions(0); auto dot = computation_->AddInstruction(HloInstruction::CreateDot( - dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); - dot->set_precision_config(convolution->precision_config()); + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers, + convolution->precision_config())); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 43a891e4fa..019840b476 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1013,6 +1013,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { 1); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { auto builder = HloComputation::Builder(TestName()); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1044,7 +1051,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { dim->set_window_reversal(false); // Create add computation. builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(builder.Build()); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -2260,9 +2268,11 @@ TEST_P(ConvInputPaddingTest, DoTest) { .ValueOrDie(); builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(lhs_pad->shape(), filter->shape(), - window, dnums) + /*feature_group_count=*/1, window, + dnums) .ValueOrDie(), - lhs_pad, filter, window, dnums)); + lhs_pad, filter, /*feature_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -2368,9 +2378,11 @@ TEST_P(ConvFilterPaddingTest, DoIt) { .ValueOrDie(); auto* orig_conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(input->shape(), rhs_pad->shape(), - window, dnums) + /*feature_group_count=*/1, window, + dnums) .ValueOrDie(), - input, rhs_pad, window, dnums)); + input, rhs_pad, /*feature_group_count=*/1, window, dnums, + DefaultPrecisionConfig(2))); // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. @@ -2522,8 +2534,9 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { HloInstruction* filter = b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter")); - b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, - window, dnums)); + b.AddInstruction(HloInstruction::CreateConvolve( + out_shape, input, filter, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); // TODO(b/80488902): verify this module. auto module = HloTestBase::CreateNewModule(); @@ -2901,7 +2914,8 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums, + DefaultPrecisionConfig(2))); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -3253,8 +3267,8 @@ TEST_P(DotStrengthReductionTest, DotStrengthReduction) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -3329,8 +3343,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3393,8 +3407,8 @@ TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { dot_dnums.add_rhs_contracting_dimensions(0); Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3511,8 +3525,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantRHS) { int64 dot_row_size = 1; int64 dot_col_size = spec.n; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, ds, rhs, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, ds, rhs, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, @@ -3581,8 +3595,8 @@ TEST_P(DotOfGatherSimplificationTest, ConstantLHS) { int64 dot_row_size = spec.m; int64 dot_col_size = 1; Shape dot_shape = ShapeUtil::MakeShape(F32, {dot_row_size, dot_col_size}); - builder.AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, ds, dot_dnums)); + builder.AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, ds, dot_dnums, DefaultPrecisionConfig(2))); auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, diff --git a/tensorflow/compiler/xla/service/batch_dot_simplification.cc b/tensorflow/compiler/xla/service/batch_dot_simplification.cc index a16b85a0a5..eda026ac56 100644 --- a/tensorflow/compiler/xla/service/batch_dot_simplification.cc +++ b/tensorflow/compiler/xla/service/batch_dot_simplification.cc @@ -63,8 +63,8 @@ BatchDotSimplification::ElideDegenerateBatchDimensionFromBatchDot( new_dim_numbers.rhs_contracting_dimensions(0) - degenerate_dims.size()); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot, - MakeDotHlo(new_lhs, new_rhs, new_dim_numbers)); - new_dot->set_precision_config(batch_dot->precision_config()); + MakeDotHlo(new_lhs, new_rhs, new_dim_numbers, + batch_dot->precision_config())); TF_ASSIGN_OR_RETURN(HloInstruction * new_dot_reshaped, MakeReshapeHlo(batch_dot->shape(), new_dot)); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index b08705d4c2..d480d72297 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -308,8 +308,11 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8bd1533972..7398f105a0 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1490,10 +1490,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_ab = builder.AddInstruction( - HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); - auto dot_bc = builder.AddInstruction( - HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); + auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot( + shape_2x4, param_a, param_b, dot_dnums, precision_config)); + auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot( + shape_3x4, param_b, param_c, dot_dnums, precision_config)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 0)); diff --git a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc index 9c81a86bbb..0826380f65 100644 --- a/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc +++ b/tensorflow/compiler/xla/service/convolution_feature_group_converter.cc @@ -223,8 +223,8 @@ Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) { filter_mask, expanded_filter, zero_filter)); auto new_convolution = HloInstruction::CreateConvolve( convolution->shape(), convolution->mutable_operand(0), new_filter, - convolution->window(), dim_numbers, /*feature_group_count=*/1); - new_convolution->set_precision_config(convolution->precision_config()); + /*feature_group_count=*/1, convolution->window(), dim_numbers, + convolution->precision_config()); TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( convolution, std::move(new_convolution))); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 098ce17a56..2d9978404c 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -130,9 +130,9 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { // change the dimension mapping but not the dimension sizes. For // example, input height and width are the same as before the reshapes. HloInstruction* new_conv = module->entry_computation()->AddInstruction( - HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, - hlo->window(), new_dnums)); - new_conv->set_precision_config(hlo->precision_config()); + HloInstruction::CreateConvolve( + new_conv_shape, new_input, new_kernel, hlo->feature_group_count(), + hlo->window(), new_dnums, hlo->precision_config())); // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 547d4c696d..616c453750 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -56,6 +56,13 @@ class ConvCanonicalizationTest : public HloTestBase { static constexpr int kOutputFeatureCount = 64; }; +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. @@ -84,7 +91,8 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -146,7 +154,8 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape( F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}), - input, kernel, conv_window_, dnums)); + input, kernel, /*feature_group_count=*/1, conv_window_, dnums, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 284929ca07..6bd0a2dd90 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -38,7 +38,11 @@ std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, + precision_config); } TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc index 09cb10d6ee..b2ba261790 100644 --- a/tensorflow/compiler/xla/service/dot_decomposer.cc +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -134,9 +134,9 @@ Status DecomposeBatchDot(HloInstruction* dot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( - dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); - dot_r2->set_precision_config(dot->precision_config()); + auto dot_r2 = computation->AddInstruction( + HloInstruction::CreateDot(dot_shape_r2, lhs_slice_r2, rhs_slice_r2, + dot_dnums, dot->precision_config())); // Reshape Dot to R3 so we can concat along batch dimension. auto dot_r3 = computation->AddInstruction( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 46c23db465..9b46bfc098 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -95,6 +95,13 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -107,12 +114,12 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { conv_window.mutable_dimensions(1)->set_size(2); conv_window.mutable_dimensions(1)->set_window_dilation(2); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -135,12 +142,12 @@ TEST_F(CudnnConvolutionRewriterTest, Window conv_window = default_conv_window_; conv_window.mutable_dimensions(1)->set_size(3); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(activations->shape(), - gradients->shape(), conv_window, - tf_default_dnums_for_backward_filter_) + ShapeInference::InferConvolveShape( + activations->shape(), gradients->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_filter_) .ConsumeValueOrDie(), - activations, gradients, conv_window, - tf_default_dnums_for_backward_filter_)); + activations, gradients, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -170,7 +177,8 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -200,7 +208,8 @@ TEST_F(CudnnConvolutionRewriterTest, } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -228,7 +237,8 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { } builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_filter_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -272,13 +282,14 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output, - /*rhs=*/reverse_kernel, conv_window, conv_dnums)); + /*rhs=*/reverse_kernel, /*feature_group_count=*/1, conv_window, + conv_dnums, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), - ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, conv_dnums) - .ValueOrDie())); + conv->shape(), ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), + /*feature_group_count=*/1, conv_window, conv_dnums) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -319,11 +330,11 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { builder.AddInstruction(HloInstruction::CreateConvolve( ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - conv_window, + /*feature_group_count=*/1, conv_window, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, conv_window, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -350,12 +361,13 @@ TEST_F(CudnnConvolutionRewriterTest, 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel")); builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape(output->shape(), kernel->shape(), - default_conv_window_, - tf_default_dnums_for_backward_input_) + ShapeInference::InferConvolveShape( + output->shape(), kernel->shape(), /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_) .ConsumeValueOrDie(), - /*lhs=*/output, /*rhs=*/kernel, default_conv_window_, - tf_default_dnums_for_backward_input_)); + /*lhs=*/output, /*rhs=*/kernel, /*feature_group_count=*/1, + default_conv_window_, tf_default_dnums_for_backward_input_, + DefaultPrecisionConfig(2))); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -402,13 +414,15 @@ TEST_F(CudnnConvolutionRewriterTest, } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -449,13 +463,15 @@ TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { } HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -502,13 +518,15 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_base_dilation(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); const HloComputation* entry_computation = @@ -554,13 +572,15 @@ TEST_F(CudnnConvolutionRewriterTest, forward_conv_col_dim->set_padding_high(2); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel, - conv_window, tf_default_dnums_for_backward_input_)); + /*feature_group_count=*/1, conv_window, + tf_default_dnums_for_backward_input_, DefaultPrecisionConfig(2))); // Verify the convolution's shape is consistent with ShapeInference. CHECK(ShapeUtil::Compatible( - conv->shape(), ShapeInference::InferConvolveShape( - output->shape(), reverse_kernel->shape(), conv_window, - tf_default_dnums_for_backward_input_) - .ValueOrDie())); + conv->shape(), + ShapeInference::InferConvolveShape( + output->shape(), reverse_kernel->shape(), /*feature_group_count=*/1, + conv_window, tf_default_dnums_for_backward_input_) + .ValueOrDie())); auto module = CreateNewModule(); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index a2be89511b..0a49d85c6d 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -112,8 +112,11 @@ std::unique_ptr MakeBigGraph() { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfigProto::DEFAULT); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + vshape, clamp, param_v0, dot_dnums, precision_config)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 5f85f14565..576c5ff7a4 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -353,6 +353,13 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) { (neg_buffer == output_buffer_1)); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(HeapSimulatorTest, MultiplyDot) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( @@ -366,8 +373,8 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. @@ -402,8 +409,8 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -440,10 +447,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); // The buffer for dot1 is the output. No buffers can be shared. The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -481,10 +488,10 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot0 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); - auto dot1 = builder.AddInstruction( - HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); + auto dot0 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, mul, paramY, dot_dnums, DefaultPrecisionConfig(2))); + auto dot1 = builder.AddInstruction(HloInstruction::CreateDot( + f32vec4_, dot0, paramY, dot_dnums, DefaultPrecisionConfig(2))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index f7ed1b0316..a2c1ce34c6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -601,8 +601,11 @@ TEST_F(HloComputationTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -633,8 +636,11 @@ TEST_F(HloComputationTest, StringificationIndent) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -666,8 +672,11 @@ TEST_F(HloComputationTest, StringificationCanonical) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index 19ffb465c0..a6ae0337a5 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -61,15 +61,18 @@ StatusOr MakeSliceHlo(HloInstruction* operand, } StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers) { + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); - TF_ASSIGN_OR_RETURN(Shape convolve_shape, ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), - window, dimension_numbers)); + TF_ASSIGN_OR_RETURN(Shape convolve_shape, + ShapeInference::InferConvolveShape( + lhs->shape(), rhs->shape(), feature_group_count, + window, dimension_numbers)); return computation->AddInstruction(HloInstruction::CreateConvolve( - convolve_shape, lhs, rhs, window, dimension_numbers)); + convolve_shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config)); } StatusOr MakeTransposeHlo(HloInstruction* operand, @@ -164,15 +167,17 @@ StatusOr MakeConcatHlo( HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } -StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers) { +StatusOr MakeDotHlo( + HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( Shape dot_shape, ShapeInference::InferDotOpShape(lhs->shape(), rhs->shape(), dim_numbers)); - return computation->AddInstruction( - HloInstruction::CreateDot(dot_shape, lhs, rhs, dim_numbers)); + return computation->AddInstruction(HloInstruction::CreateDot( + dot_shape, lhs, rhs, dim_numbers, precision_config)); } StatusOr MakeMapHlo(absl::Span operands, diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index a1c4b374d1..1c82956907 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -48,8 +48,9 @@ StatusOr MakeSliceHlo(HloInstruction* operand, // Creates a convolution HLO instruction and adds it to the computation // containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation). StatusOr MakeConvolveHlo( - HloInstruction* lhs, HloInstruction* rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers); + HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. @@ -97,8 +98,10 @@ StatusOr MakeConcatHlo( // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). -StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers); +StatusOr MakeDotHlo( + HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index d1a96c10f8..62eea2b06c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2334,8 +2334,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + 2, PrecisionConfigProto::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 441dcad000..ffb3451164 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -53,7 +53,6 @@ namespace xla { namespace { - template StatusOr> Compare(const Shape& shape, HloOpcode opcode, LiteralSlice lhs_literal, @@ -345,7 +344,8 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( } StatusOr> HloEvaluator::EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.CloneToUnique()); @@ -358,7 +358,7 @@ StatusOr> HloEvaluator::EvaluateDotOp( std::unique_ptr cloned_instruction = HloInstruction::CreateDot(dot_shape, lhs_instr.get(), rhs_instr.get(), - dim_numbers); + dim_numbers, precision_config); return Evaluate(cloned_instruction.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index c2d49e56ac..e13af8e999 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -115,7 +115,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { HloOpcode opcode, const Literal& operand); StatusOr> EvaluateDotOp( - const DotDimensionNumbers& dim_numbers, const Literal& lhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, const Literal& lhs, const Literal& rhs); protected: diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 7e490d7f32..3ab8ef18dd 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -622,6 +622,13 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_P(HloEvaluatorTest, DotRank2AndRank1) { HloComputation::Builder b(TestName()); @@ -649,7 +656,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -694,7 +702,8 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) { dot_dnums.add_lhs_contracting_dimensions(0); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -737,7 +746,8 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) { dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, - rhs_instruction, dot_dnums)); + rhs_instruction, dot_dnums, + DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -790,7 +800,8 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -844,7 +855,8 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -927,7 +939,8 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -1004,7 +1017,8 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -1063,7 +1077,8 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -1126,7 +1141,8 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); @@ -1197,7 +1213,8 @@ TEST_P(HloEvaluatorTest, const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( - shape, lhs_instruction, rhs_instruction, window, dnums)); + shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, + window, dnums, DefaultPrecisionConfig(2))); module().AddEntryComputation(b.Build()); std::unique_ptr result = Evaluate(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index cb27e13e99..dc16a84246 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1021,9 +1021,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { CHECK_EQ(num_spatial_dims + 2, lhs_rank); CHECK_EQ(num_spatial_dims + 2, rhs_rank); - TF_ASSIGN_OR_RETURN(auto inferred_return_shape, - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, - window, dnums)); + TF_ASSIGN_OR_RETURN( + auto inferred_return_shape, + ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, conv->feature_group_count(), window, dnums)); CHECK(ShapeUtil::Compatible(result_shape, inferred_return_shape)) << "return shape set to: " << ShapeUtil::HumanString(result_shape) << " but is inferred to be: " diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 6d13f85cbb..f25761ac70 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -341,17 +341,21 @@ StatusOr> HloInstruction::CreateFromProto( source_target_pairs); break; } - case HloOpcode::kConvolution: + case HloOpcode::kConvolution: { TF_RET_CHECK(proto.operand_ids_size() == 2) << "Convolution instruction should have 2 operands but sees " << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); + PrecisionConfigProto precision_config = proto.precision_config(); + precision_config.mutable_operand_precision()->Resize( + proto.operand_ids_size(), PrecisionConfigProto::DEFAULT); instruction = CreateConvolve( - proto.shape(), operands(0), operands(1), proto.window(), - proto.convolution_dimension_numbers(), - std::max(static_cast(proto.feature_group_count()), 1LL)); + proto.shape(), operands(0), operands(1), + std::max(proto.feature_group_count(), 1), proto.window(), + proto.convolution_dimension_numbers(), precision_config); break; + } case HloOpcode::kReduceWindow: TF_RET_CHECK(proto.operand_ids_size() == 2) << "ReduceWindow instruction should have 2 operands but sees " @@ -468,6 +472,20 @@ StatusOr> HloInstruction::CreateFromProto( computation_map.at(computation_id)); } } + if (instruction->opcode() == HloOpcode::kDot) { + instruction->precision_config_ = proto.precision_config(); + instruction->precision_config_.mutable_operand_precision()->Resize( + instruction->operand_count(), PrecisionConfigProto::DEFAULT); + TF_RET_CHECK(proto.has_dot_dimension_numbers()); + instruction->dot_dimension_numbers_ = + absl::make_unique( + proto.dot_dimension_numbers()); + } else { + TF_RET_CHECK(!proto.has_precision_config()) + << instruction->opcode() << proto.DebugString(); + TF_RET_CHECK(!proto.has_dot_dimension_numbers()) + << instruction->opcode(); + } break; } } @@ -476,12 +494,6 @@ StatusOr> HloInstruction::CreateFromProto( instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); instruction->backend_config_ = proto.backend_config(); - instruction->precision_config_ = proto.precision_config(); - - if (proto.has_dot_dimension_numbers()) { - instruction->dot_dimension_numbers_ = - absl::make_unique(proto.dot_dimension_numbers()); - } if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -643,10 +655,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) { + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) { return absl::make_unique( - shape, lhs, rhs, window, dimension_numbers, feature_group_count); + shape, lhs, rhs, feature_group_count, window, dimension_numbers, + precision_config); } /* static */ std::unique_ptr HloInstruction::CreateFft( @@ -658,13 +672,15 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers) { + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) { auto instruction = absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); instruction->AppendOperand(rhs); instruction->dot_dimension_numbers_ = absl::make_unique(dimension_numbers); + instruction->set_precision_config(precision_config); return instruction; } @@ -1057,7 +1073,6 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->clear_sharding(); } derived_instruction->set_metadata(metadata_); - derived_instruction->set_precision_config(precision_config_); } bool HloInstruction::HasSideEffectNoRecurse() const { @@ -1278,7 +1293,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDot: CHECK_EQ(new_operands.size(), 2); clone = CreateDot(shape, new_operands[0], new_operands[1], - *dot_dimension_numbers_); + *dot_dimension_numbers_, precision_config()); break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); @@ -2167,7 +2182,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_metadata() = metadata_; proto.set_backend_config(backend_config_); - *proto.mutable_precision_config() = precision_config_; + if (opcode() == HloOpcode::kConvolution || opcode() == HloOpcode::kDot) { + *proto.mutable_precision_config() = precision_config_; + } if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); @@ -2948,7 +2965,11 @@ StatusOr StringToRandomDistribution(const string& name) { } string HloInstruction::PrecisionConfigToString() const { - if (precision_config_.operand_precision().empty()) { + if (absl::c_all_of( + precision_config_.operand_precision(), [](int32 precision) { + return static_cast(precision) == + PrecisionConfigProto::DEFAULT; + })) { return ""; } return StrCat( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index cca134e8b4..55d592ff94 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -405,9 +405,9 @@ class HloInstruction { // and window describes how the filter is applied to lhs. static std::unique_ptr CreateConvolve( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const PrecisionConfigProto& precision_config); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( @@ -418,7 +418,8 @@ class HloInstruction { // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dimension_numbers); + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config); // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 76b0e940a6..b4e302e832 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1122,6 +1122,13 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { } } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { // Fused expression: // @@ -1147,8 +1154,8 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1188,8 +1195,8 @@ TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(s, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + s, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1239,8 +1246,8 @@ TEST_F(HloInstructionTest, NestedFusionEquality) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); + auto dot = builder.AddInstruction(HloInstruction::CreateDot( + data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2))); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); auto add_operand = builder.AddInstruction( @@ -1320,8 +1327,8 @@ TEST_F(HloInstructionTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().set_print_metadata(false); @@ -1485,8 +1492,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto options = HloPrintOptions().Canonical(); @@ -1527,8 +1534,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationWhile) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); @@ -1583,8 +1590,8 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule(); auto* computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e46afa764f..bed273149b 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1628,12 +1628,13 @@ std::unique_ptr HloOutfeedInstruction::CloneWithNewOperandsImpl( HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count) + int64 feature_group_count, const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers, + const PrecisionConfigProto& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), + feature_group_count_(feature_group_count), window_(window), - convolution_dimension_numbers_(dimension_numbers), - feature_group_count_(feature_group_count) { + convolution_dimension_numbers_(dimension_numbers) { if (window_util::HasBaseDilation(window)) { SetAndSanitizeName(StrCat(name(), "-base-dilated")); } @@ -1642,6 +1643,7 @@ HloConvolutionInstruction::HloConvolutionInstruction( } AppendOperand(lhs); AppendOperand(rhs); + set_precision_config(precision_config); } string HloConvolutionInstruction::ToCategory() const { @@ -1697,8 +1699,8 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl( HloCloneContext* context) const { CHECK_EQ(new_operands.size(), 2); return absl::make_unique( - shape, new_operands[0], new_operands[1], window(), - convolution_dimension_numbers_, feature_group_count_); + shape, new_operands[0], new_operands[1], feature_group_count_, window(), + convolution_dimension_numbers_, precision_config()); } HloReduceWindowInstruction::HloReduceWindowInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 3230383579..1c85aa4681 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -942,9 +942,9 @@ class HloConvolutionInstruction : public HloInstruction { public: explicit HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, - const Window& window, + int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count); + const PrecisionConfigProto& precision_config); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { @@ -972,12 +972,13 @@ class HloConvolutionInstruction : public HloInstruction { std::unique_ptr CloneWithNewOperandsImpl( const Shape& shape, absl::Span new_operands, HloCloneContext* context) const override; - Window window_; - // Describes the dimension numbers used for a convolution. - ConvolutionDimensionNumbers convolution_dimension_numbers_; // The number of feature groups. Must be a divisor of the input feature // dimension and output feature dimension. int64 feature_group_count_; + // Describes the window used for a convolution. + Window window_; + // Describes the dimension numbers used for a convolution. + ConvolutionDimensionNumbers convolution_dimension_numbers_; }; class HloReduceWindowInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ea8e6a239a..62f01c4adb 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -530,10 +530,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; - optional> operand_precision; - attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, - &operand_precision}; - HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -913,6 +909,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { return false; @@ -923,9 +922,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } + PrecisionConfigProto precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfigProto::DEFAULT); + } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums, - feature_group_count.value())); + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], + feature_group_count.value(), *window, *dnums, precision_config)); break; } case HloOpcode::kFft: { @@ -1272,6 +1279,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; + optional> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || !ParseAttributes(attrs)) { @@ -1296,8 +1306,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, rhs_batch_dims->end()}; } - instruction = builder->AddInstruction( - HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); + PrecisionConfigProto precision_config; + if (operand_precision) { + *precision_config.mutable_operand_precision() = { + operand_precision->begin(), operand_precision->end()}; + } else { + precision_config.mutable_operand_precision()->Resize( + operands.size(), PrecisionConfigProto::DEFAULT); + } + + instruction = builder->AddInstruction(HloInstruction::CreateDot( + shape, operands[0], operands[1], dnum, precision_config)); break; } case HloOpcode::kGather: { @@ -1414,12 +1433,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } - if (operand_precision) { - PrecisionConfigProto precision_config; - *precision_config.mutable_operand_precision() = {operand_precision->begin(), - operand_precision->end()}; - instruction->set_precision_config(precision_config); - } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 95516dec74..069586a738 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -86,8 +86,8 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { const Shape expected, ShapeInference::InferConvolveShape( convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), convolution->convolution_dimension_numbers(), - convolution->feature_group_count())); + convolution->feature_group_count(), convolution->window(), + convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index a4de02a890..4a71ee909b 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -165,6 +165,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayFor( TF_ASSIGN_OR_RETURN( computed_array, ComputeArrayForDot(instr->shape(), instr->dot_dimension_numbers(), + instr->precision_config(), FindOrDie(cache_, instr->operand(0)), FindOrDie(cache_, instr->operand(1)))); } else { @@ -1030,6 +1031,7 @@ bool CanFoldDotIntoIndexedArray( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " << ToString(rhs); @@ -1045,9 +1047,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( new_dim_numbers.set_lhs_contracting_dimensions( 0, lhs->source_dim() == (lhs_rank - 1) ? (lhs_rank - 2) : (lhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, lhs->literal(), *rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, lhs->literal(), *rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting LHS // dimension "went". @@ -1063,7 +1066,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { + const PrecisionConfigProto& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1079,9 +1083,10 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( new_dim_numbers.set_rhs_contracting_dimensions( 0, rhs->source_dim() == (rhs_rank - 1) ? (rhs_rank - 2) : (rhs_rank - 1)); - TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, - TakeOwnership(HloEvaluator{}.EvaluateDotOp( - new_dim_numbers, *lhs->literal(), rhs->literal()))); + TF_ASSIGN_OR_RETURN( + Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateDotOp( + new_dim_numbers, precision_config, *lhs->literal(), rhs->literal()))); // The new source dimension is wherever the non-batch non-contracting RHS // dimension "went". @@ -1095,8 +1100,8 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( } StatusOr IndexedArrayAnalysis::ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, Array* lhs, - Array* rhs) { + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) { // Intuitively, if // // - The LHS of a dot product is a gathered sequence of rows from a constant @@ -1119,6 +1124,7 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( dynamic_cast(lhs)) { if (auto* rhs_constant = dynamic_cast(rhs)) { return ComputeArrayForDotWithIndexedLhs(shape, dim_numbers, + precision_config, lhs_indexed_array, rhs_constant); } } @@ -1126,7 +1132,8 @@ StatusOr IndexedArrayAnalysis::ComputeArrayForDot( if (auto* rhs_indexed_array = dynamic_cast(rhs)) { if (auto* lhs_constant = dynamic_cast(lhs)) { - return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, lhs_constant, + return ComputeArrayForDotWithIndexedRhs(shape, dim_numbers, + precision_config, lhs_constant, rhs_indexed_array); } } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index dcfb725535..f21e784a4d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -267,15 +267,17 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, ScalarIndexedConstantArray* lhs, ConstantArray* rhs); StatusOr ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - ConstantArray* lhs, ScalarIndexedConstantArray* rhs); + const PrecisionConfigProto& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs); - StatusOr ComputeArrayForDot(const Shape& shape, - const DotDimensionNumbers& dim_numbers, - Array* lhs, Array* rhs); + StatusOr ComputeArrayForDot( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 2611749862..7758a5dd4d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1552,8 +1552,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dnums, int64 feature_group_count) { + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, const ConvolutionDimensionNumbers& dnums) { TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution")); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index a28345acef..96a0ee165d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -108,9 +108,9 @@ class ShapeInference { // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( - const Shape& lhs, const Shape& rhs, const Window& window, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1); + const Shape& lhs, const Shape& rhs, int64 feature_group_count, + const Window& window, + const ConvolutionDimensionNumbers& dimension_numbers); // Infers the shape produced by the given FFT type on the given operand. static StatusOr InferFftShape(const Shape& in, FftType fft_type, diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index cc92e58ef8..864ed43118 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -419,8 +419,8 @@ TEST_F(ShapeInferenceTest, Convolve) { dim1->set_padding_high(0); dim1->set_window_dilation(1); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), @@ -464,8 +464,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(2); dim1->set_base_dilation(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), @@ -509,8 +509,8 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dim1->set_padding_high(1); dim1->set_window_dilation(1); dim1->set_base_dilation(2); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_IS_OK(inferred_status.status()); Shape inferred_shape = inferred_status.ValueOrDie(); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), @@ -547,8 +547,8 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dim1->set_stride(2); dim1->set_padding_low(1); dim1->set_padding_high(1); - auto inferred_status = - ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums); + auto inferred_status = ShapeInference::InferConvolveShape( + lhs_shape, rhs_shape, /*feature_group_count=*/1, window, dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("each dimension exactly once")); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 530f40e4b2..7c1f4b5cc6 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -108,8 +108,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) { } std::unique_ptr new_dot = HloInstruction::CreateDot( - dot->shape(), new_lhs, new_rhs, new_dim_numbers); - new_dot->set_precision_config(dot->precision_config()); + dot->shape(), new_lhs, new_rhs, new_dim_numbers, dot->precision_config()); return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot)); } @@ -178,8 +177,8 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - new_conv->set_precision_config(convolution.precision_config()); + convolution.shape(), new_lhs, new_rhs, convolution.feature_group_count(), + convolution.window(), new_dnums, convolution.precision_config()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 58f767e913..e486a00e53 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -215,6 +215,13 @@ ENTRY entry_computation { /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + // Test that a two dimension swap of the kernel gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { auto builder = HloComputation::Builder("entry_computation"); @@ -240,10 +247,12 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -293,10 +302,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { transpose_y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - x->shape(), transpose_y->shape(), window, dnums); + x->shape(), transpose_y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), x, transpose_y, window, dnums)); + conv_shape.ValueOrDie(), x, transpose_y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -351,10 +362,12 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = @@ -415,10 +428,12 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); } StatusOr conv_shape = ShapeInference::InferConvolveShape( - transpose_x->shape(), y->shape(), window, dnums); + transpose_x->shape(), y->shape(), /*feature_group_count=*/1, window, + dnums); EXPECT_IS_OK(conv_shape); HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( - conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + conv_shape.ValueOrDie(), transpose_x, y, + /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); auto module = CreateNewModule("test_module"); HloComputation* entry_computation = diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index a32d1f9026..e3328203a6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1064,8 +1064,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + /*new_size=*/2, PrecisionConfigProto::DEFAULT); auto dot = builder.AddInstruction( - HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 05f90ba9fb..53b5e933b6 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -47,6 +47,12 @@ limitations under the License. namespace xla { namespace { +PrecisionConfigProto DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} class MultiOutputFusionTest : public HloTestBase { protected: @@ -90,8 +96,8 @@ class MultiOutputFusionTest : public HloTestBase { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + elem_shape2, sub, add2, dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -154,7 +160,7 @@ class MultiOutputFusionTest : public HloTestBase { dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( ShapeUtil::MakeShapeWithDescendingLayout(F32, {1}), sub, reshape, - dot_dnums)); + dot_dnums, DefaultPrecisionConfig(2))); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { -- GitLab From 965e3b0ca01ed7cc951131454b38ab638ff44fbf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 11:18:50 -0700 Subject: [PATCH 055/540] Extend hoisting monotonic functions out of min/max reductions to all monotonic unary functions. Add the ability to flip Max <-> Min if the function is non-increasing, e.g. Max(Neg(x)) => Neg(Min(x)). PiperOrigin-RevId: 211490436 --- tensorflow/core/grappler/op_types.cc | 37 ++++++++++++---- tensorflow/core/grappler/op_types.h | 2 +- .../optimizers/arithmetic_optimizer.cc | 10 ++++- .../optimizers/arithmetic_optimizer_test.cc | 42 +++++++++++++++++++ 4 files changed, 80 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 653b088b1d..e78239bd43 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) { bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } -bool IsElementWiseMonotonic(const NodeDef& node) { - static const std::unordered_set* element_wise_monotonic_ops = +// Returns true if node represents a unary elementwise function that is +// monotonic. If *is_non_decreasing is true, the function is non-decreasing, +// e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing, +// e.g. inv. +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { + static const std::unordered_set* monotonic_non_decreasing_ops = CHECK_NOTNULL((new std::unordered_set{ - "Relu", - "Relu6", - "Sigmoid", - "Sqrt", - "Tanh", + "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1", + "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint", + "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh", + })); + static const std::unordered_set* monotonic_non_increasing_ops = + CHECK_NOTNULL((new std::unordered_set{ + "Inv", + "Reciprocal", + "Erfc", + "Rsqrt", + "Neg", })); - return element_wise_monotonic_ops->count(node.op()) > 0; + if (monotonic_non_decreasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = true; + } + return true; + } else if (monotonic_non_increasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = false; + } + return true; + } + return false; } bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 94439265c9..25ab6b65ac 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); -bool IsElementWiseMonotonic(const NodeDef& node); +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 4fed88d536..65947ddce5 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2706,8 +2706,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { // 0. inner_function is not in the preserve set, // 1. inner_function's Op is element-wise monotonic // 2. inner_function's output is not being consumed elsewhere. + bool is_non_decreasing = false; if (!IsInPreserveSet(*inner_function) && - IsElementWiseMonotonic(*inner_function) && + IsElementWiseMonotonic(*inner_function, &is_non_decreasing) && ctx().node_map->GetOutputs(inner_function->name()).size() == 1) { // Swap the first inputs of the inner function Op & the reduction Op. NodeDef* inner_input; @@ -2719,7 +2720,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { UpdateConsumers(reduction_node, inner_function->name()); ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(), reduction_node->name()); - + if (!is_non_decreasing) { + // Flip Min<->Max if the function is non-increasing, e.g. + // Max(Neg(x)) = Neg(Min(x)). + const string opposite = IsMax(*reduction_node) ? "Min" : "Max"; + reduction_node->set_op(opposite); + } AddToOptimizationQueue(reduction_node); AddToOptimizationQueue(inner_function); AddToOptimizationQueue(inner_input); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index bfccc0affd..39517edc06 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -3248,6 +3248,48 @@ TEST_F(ArithmeticOptimizerTest, VerifyGraphsMatch(item.graph, output, __LINE__); } +TEST_F(ArithmeticOptimizerTest, + OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output neg = ops::Neg(s.WithOpName("neg"), x); + Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0}); + Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max); + + GrapplerItem item; + item.fetch = {"final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + + test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + EXPECT_EQ(item.graph.node_size(), output.node_size()); + // Check if the inputs are switched + int required_node_count = 0; + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "neg") { + EXPECT_EQ("Neg", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("reduce_max", node.input(0)); + ++required_node_count; + } else if (node.name() == "reduce_max") { + EXPECT_EQ("Min", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + ++required_node_count; + } + } + EXPECT_EQ(2, required_node_count); +} + TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); -- GitLab From 7ac5c1ed94eae6e23dc9bc42dc99bf4c500b71a6 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Tue, 4 Sep 2018 11:48:52 -0700 Subject: [PATCH 056/540] [TF:XLA] Bump open source llvm revision to r341347 PiperOrigin-RevId: 211496283 --- tensorflow/workspace.bzl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index fdbb1bf383..01f82cc68a 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -491,11 +491,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/67bd0d9a0f5597f57f272061fd70f24dffb3d223.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/dc6d9ec3646865125d057b6f515b4543df79920a.tar.gz", ], - sha256 = "b8f4ffbcaeea345e2245fd7028c7e960d71c2a2007c20bbfc5d79ecc86992a5e", - strip_prefix = "llvm-67bd0d9a0f5597f57f272061fd70f24dffb3d223", + sha256 = "c7252290a113f694cccbb4b325c67b56f3aa6f5b3044524302c0e79db2da7e2a", + strip_prefix = "llvm-dc6d9ec3646865125d057b6f515b4543df79920a", build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"), ) -- GitLab From 22e855159462b502dc3af138d254214bd02cf68b Mon Sep 17 00:00:00 2001 From: Guangda Lai Date: Tue, 4 Sep 2018 11:49:16 -0700 Subject: [PATCH 057/540] Fix the tensorrt dependency order in tensorflow/contrib/BUILD. PiperOrigin-RevId: 211496364 --- tensorflow/contrib/BUILD | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 66983801bf..798f499870 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -20,13 +20,7 @@ py_library( ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = if_not_windows([ - # TODO(aaroey): tensorrt dependency has to appear before tflite so the - # build can resolve its flatbuffers symbols within the tensorrt library. - # This is an issue with the tensorrt static library and will be fixed by - # the next tensorrt release, so fix the order here after that. - "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows - ]) + [ + deps = [ "//tensorflow/contrib/all_reduce", "//tensorflow/contrib/batching:batch_py", "//tensorflow/contrib/bayesflow:bayesflow_py", @@ -135,6 +129,7 @@ py_library( ]) + if_not_windows([ "//tensorflow/contrib/bigtable", # depends on bigtable "//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows + "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", ]), ) -- GitLab From 2fcec016cec1ec70ba715c9b2f4c759c71eaafca Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 12:02:22 -0700 Subject: [PATCH 058/540] Add IsValidSignature method to signature_def_utils PiperOrigin-RevId: 211498364 --- .../contrib/saved_model/cc/saved_model/BUILD | 2 + .../cc/saved_model/signature_def_utils.cc | 81 ++++++++++++ .../cc/saved_model/signature_def_utils.h | 3 + .../saved_model/signature_def_utils_test.cc | 123 ++++++++++++++++-- 4 files changed, 201 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD index 3c616c555b..ea4d41d43b 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD @@ -30,6 +30,7 @@ cc_library( hdrs = ["signature_def_utils.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", @@ -42,6 +43,7 @@ tf_cc_test( srcs = ["signature_def_utils_test.cc"], deps = [ ":signature_def_utils", + "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc index a45908d272..e87e497e5f 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" +#include "tensorflow/cc/saved_model/signature_constants.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/protobuf.h" @@ -33,6 +35,79 @@ Status FindInProtobufMap(StringPiece description, *value = &it->second; return Status::OK(); } + +// Looks up the TensorInfo for the given key in the given map and verifies that +// its datatype matches the given correct datatype. +bool VerifyTensorInfoForKeyInMap(const protobuf::Map& map, + const string& key, DataType correct_dtype) { + const TensorInfo* tensor_info; + const Status& status = FindInProtobufMap("", map, key, &tensor_info); + if (!status.ok()) { + return false; + } + if (tensor_info->dtype() != correct_dtype) { + return false; + } + return true; +} + +bool IsValidPredictSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kPredictMethodName) { + return false; + } + if (signature_def.inputs().empty()) { + return false; + } + if (signature_def.outputs().empty()) { + return false; + } + return true; +} + +bool IsValidRegressionSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kRegressMethodName) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs, + DT_STRING)) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs, + DT_FLOAT)) { + return false; + } + return true; +} + +bool IsValidClassificationSignature(const SignatureDef& signature_def) { + if (signature_def.method_name() != kClassifyMethodName) { + return false; + } + if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs, + DT_STRING)) { + return false; + } + if (signature_def.outputs().empty()) { + return false; + } + for (auto const& output : signature_def.outputs()) { + const string& key = output.first; + const TensorInfo& tensor_info = output.second; + if (key == kClassifyOutputClasses) { + if (tensor_info.dtype() != DT_STRING) { + return false; + } + } else if (key == kClassifyOutputScores) { + if (tensor_info.dtype() != DT_FLOAT) { + return false; + } + } else { + return false; + } + } + return true; +} + } // namespace Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def, @@ -74,4 +149,10 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def, return Status::OK(); } +bool IsValidSignature(const SignatureDef& signature_def) { + return IsValidClassificationSignature(signature_def) || + IsValidRegressionSignature(signature_def) || + IsValidPredictSignature(signature_def); +} + } // namespace tensorflow diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h index b732cdd41e..bb24faa989 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h @@ -64,6 +64,9 @@ Status FindInputTensorNameByKey(const SignatureDef& signature_def, Status FindOutputTensorNameByKey(const SignatureDef& signature_def, const string& tensor_info_key, string* name); +// Determine whether a SignatureDef can be served by TensorFlow Serving. +bool IsValidSignature(const SignatureDef& signature_def); + } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_ diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc index a063e95696..c743112ce0 100644 --- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc +++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h" +#include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -22,7 +23,7 @@ limitations under the License. namespace tensorflow { -class SignatureDefUtilsTest : public ::testing::Test { +class FindByKeyTest : public ::testing::Test { protected: MetaGraphDef MakeSampleMetaGraphDef() { MetaGraphDef result; @@ -32,13 +33,23 @@ class SignatureDefUtilsTest : public ::testing::Test { return result; } + void SetInputNameForKey(const string& key, const string& name, + SignatureDef* signature_def) { + (*signature_def->mutable_inputs())[key].set_name(name); + } + + void SetOutputNameForKey(const string& key, const string& name, + SignatureDef* signature_def) { + (*signature_def->mutable_outputs())[key].set_name(name); + } + SignatureDef MakeSampleSignatureDef() { SignatureDef result; result.set_method_name(kMethodName); - (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name); - (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name); - (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name); - (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name); + SetInputNameForKey(kInput1Key, kInput1Name, &result); + SetInputNameForKey(kInput2Key, kInput2Name, &result); + SetOutputNameForKey(kOutput1Key, kOutput1Name, &result); + SetOutputNameForKey(kOutput2Key, kOutput2Name, &result); return result; } @@ -54,7 +65,7 @@ class SignatureDefUtilsTest : public ::testing::Test { const string kOutput2Name = "output_two"; }; -TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) { +TEST_F(FindByKeyTest, FindSignatureDefByKey) { const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef(); const SignatureDef* signature_def; // Succeeds for an existing signature. @@ -67,7 +78,7 @@ TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) { .ok()); } -TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) { +TEST_F(FindByKeyTest, FindInputTensorNameByKey) { const SignatureDef signature_def = MakeSampleSignatureDef(); string name; // Succeeds for an existing input. @@ -78,7 +89,7 @@ TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) { FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok()); } -TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) { +TEST_F(FindByKeyTest, FindOutputTensorNameByKey) { const SignatureDef signature_def = MakeSampleSignatureDef(); string name; // Succeeds for an existing output. @@ -89,4 +100,100 @@ TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) { FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok()); } +class IsValidSignatureTest : public ::testing::Test { + protected: + void SetInputDataTypeForKey(const string& key, DataType dtype) { + (*signature_def_.mutable_inputs())[key].set_dtype(dtype); + } + + void SetOutputDataTypeForKey(const string& key, DataType dtype) { + (*signature_def_.mutable_outputs())[key].set_dtype(dtype); + } + + void EraseOutputKey(const string& key) { + (*signature_def_.mutable_outputs()).erase(key); + } + + void ExpectInvalidSignature() { + EXPECT_FALSE(IsValidSignature(signature_def_)); + } + + void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); } + + SignatureDef signature_def_; +}; + +TEST_F(IsValidSignatureTest, IsValidPredictSignature) { + signature_def_.set_method_name("not_kPredictMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kPredictMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kPredictInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kPredictOutputs, DT_STRING); + ExpectValidSignature(); +} + +TEST_F(IsValidSignatureTest, IsValidRegressionSignature) { + signature_def_.set_method_name("not_kRegressMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kRegressMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kRegressInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kRegressOutputs, DT_STRING); + // Incorrect data type + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT); + ExpectValidSignature(); +} + +TEST_F(IsValidSignatureTest, IsValidClassificationSignature) { + signature_def_.set_method_name("not_kClassifyMethodName"); + // Incorrect method name + ExpectInvalidSignature(); + + signature_def_.set_method_name(kClassifyMethodName); + // No inputs + ExpectInvalidSignature(); + + SetInputDataTypeForKey(kClassifyInputs, DT_STRING); + // No outputs + ExpectInvalidSignature(); + + SetOutputDataTypeForKey("invalidKey", DT_FLOAT); + // Invalid key + ExpectInvalidSignature(); + + EraseOutputKey("invalidKey"); + SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT); + // Invalid dtype for classes + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING); + // Valid without scores + ExpectValidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING); + // Invalid dtype for scores + ExpectInvalidSignature(); + + SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT); + // Valid with both classes and scores + ExpectValidSignature(); +} + } // namespace tensorflow -- GitLab From 5587f4eb011115b947daaa4b092ef70650705687 Mon Sep 17 00:00:00 2001 From: Guangda Lai Date: Tue, 4 Sep 2018 12:11:14 -0700 Subject: [PATCH 059/540] Enable TensorRT in ci docker build. PiperOrigin-RevId: 211500190 --- tensorflow/tools/ci_build/Dockerfile.gpu | 1 + tensorflow/tools/ci_build/install/install_deb_packages.sh | 6 ++++++ tensorflow/tools/ci_build/linux/libtensorflow_docker.sh | 1 + 3 files changed, 8 insertions(+) diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu index f05c7a4809..a4cad4b6c6 100644 --- a/tensorflow/tools/ci_build/Dockerfile.gpu +++ b/tensorflow/tools/ci_build/Dockerfile.gpu @@ -30,3 +30,4 @@ RUN mkdir /usr/local/cuda-9.0/lib && \ # Configure the build for our CUDA configuration. ENV TF_NEED_CUDA 1 +ENV TF_NEED_TENSORRT 1 diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh index 9640810533..179fc42d60 100755 --- a/tensorflow/tools/ci_build/install/install_deb_packages.sh +++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh @@ -67,6 +67,12 @@ apt-get install -y --no-install-recommends \ zip \ zlib1g-dev +apt-get update && \ + apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \ + apt-get update && \ + apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \ + apt-get install libnvinfer-dev=4.1.2-1+cuda9.0 + # populate the database updatedb diff --git a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh index f958b3c9b7..60c974c36b 100755 --- a/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh +++ b/tensorflow/tools/ci_build/linux/libtensorflow_docker.sh @@ -52,6 +52,7 @@ ${DOCKER_BINARY} run \ -e "PYTHON_BIN_PATH=/usr/bin/python" \ -e "TF_NEED_HDFS=0" \ -e "TF_NEED_CUDA=${TF_NEED_CUDA}" \ + -e "TF_NEED_TENSORRT=${TF_NEED_CUDA}" \ -e "TF_NEED_OPENCL_SYCL=0" \ "${DOCKER_IMAGE}" \ "/workspace/tensorflow/tools/ci_build/linux/libtensorflow.sh" -- GitLab From f9b58d46499c79e01f55d9e16867a8aace667db8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 12:22:27 -0700 Subject: [PATCH 060/540] Add more data fields to step proto. PiperOrigin-RevId: 211501909 --- tensorflow/contrib/tpu/profiler/tf_op_stats.proto | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index 2b13343efa..f88dc51636 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -79,12 +79,15 @@ message StepInfoResult { // The step duration in picoseconds. optional uint64 duration_ps = 2; // The infeed duration in picoseconds. - // Can turn into a map if we want a variable number of ops. optional uint64 infeed_duration_ps = 3; + // The outfeed duration in picoseconds. + optional uint64 host_outfeed_ps = 8; // The start time of this step in picoseconds. optional uint64 begin_ps = 4; // The waiting time within this step in picoseconds. optional uint64 wait_duration_ps = 5; + // The unit b outfeed duration in picoseconds. + optional uint64 unit_b_outfeed_ps = 9; // The time spent on cross-replica-sum in picoseconds. optional uint64 crs_duration_ps = 6; // Percentage of unit b time spent on infeed. -- GitLab From 28b09bfedf396553b9190db5c687e764ab9d0cec Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Tue, 4 Sep 2018 12:45:19 -0700 Subject: [PATCH 061/540] Minor tweaks to TFLite API docs PiperOrigin-RevId: 211505612 --- tensorflow/contrib/lite/g3doc/apis.md | 43 ++++++++++++++++----------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md index f255017ad9..69616c7b8a 100644 --- a/tensorflow/contrib/lite/g3doc/apis.md +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -37,7 +37,7 @@ float* output = interpreter->typed_output_tensor(0); ``` ### Data Alignment -TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended +TensorFlow Lite data is usually aligned to 16-byte boundaries. It is recommended that all data provided to TensorFlow Lite be aligned that way. ### Error Reporting @@ -112,7 +112,7 @@ below. It should be noted that: * Tensors are represented by integers, in order to avoid string comparisons (and any fixed dependency on string libraries). - * An interpreter must not be accessed from concurrent threads + * An interpreter must not be accessed from concurrent threads. * Memory allocation for input and output tensors must be triggered by calling AllocateTensors() right after resizing tensors. @@ -169,7 +169,7 @@ former provides error reporting facilities and access to global objects, including all the tensors. The latter allows implementations to access their inputs and outputs. -When the interpreter loads a model, it calls init() once for each node in the +When the interpreter loads a model, it calls `init()` once for each node in the graph. A given `init()` will be called more than once if the op is used multiple times in the graph. For custom ops a configuration buffer will be provided, containing a flexbuffer that maps parameter names to their values. @@ -210,8 +210,9 @@ namespace custom { Note that registration is not automatic and an explicit call to `Register_MY_CUSTOM_OP` should be made somewhere. While the standard -`:builtin_ops` takes care of the registration of builtins, custom ops will have -to be collected in separated custom libraries. +`BuiltinOpResolver` (available from the `:builtin_ops` target) takes care of the +registration of builtins, custom ops will have to be collected in separate +custom libraries. ### Customizing the kernel library @@ -232,7 +233,7 @@ class OpResolver { }; ``` -The regular usage will require the developer to use the `BuiltinOpResolver` and +Regular usage will require the developer to use the `BuiltinOpResolver` and write: ```c++ @@ -308,18 +309,25 @@ an `IllegalArgumentException` will be thrown. #### Inputs -Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of -the supported primitive types. +Each input should be an array or multi-dimensional array of the supported +primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is +an array or multi-dimensional array, the associated input tensor will be +implicitly resized to the array's dimensions at inference time. If the input is +a ByteBuffer, the caller should first manually resize the associated input +tensor (via `Interpreter.resizeInput()`) before running inference. -The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid -unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its -order must be `ByteOrder.nativeOrder()`. After it is used for a model inference, -it must remain unchanged until the model inference is finished. +When using 'ByteBuffer', prefer using direct byte buffers, as this allows the +`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte +buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a +model inference, it must remain unchanged until the model inference is finished. #### Outputs -Each output should be an array, or a multi-dimensional array of the supported -primitive types. +Each output should be an array or multi-dimensional array of the supported +primitive types, or a ByteBuffer of the appropriate size. Note that some models +have dynamic outputs, where the shape of output tensors can vary depending on +the input. There's no straightforward way of handling this with the existing +Java inference API, but planned extensions will make this possible. #### Running Model Inference @@ -339,9 +347,10 @@ interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); where each entry in `inputs` corresponds to an input tensor and `map_of_indices_to_outputs` maps indices of output tensors to the corresponding output data. In both cases the tensor indices should correspond to -the values given to the `TensorFlow Lite Optimized Converter` when the model was -created. Be aware that the order of tensors in `input` must match the order -given to the `TensorFlow Lite Optimized Converter`. +the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md) +when the model was created. Be aware that the order of tensors in `input` must +match the order given to the `TensorFlow Lite Optimized Converter`. + The Java API also provides convenient functions for app developers to get the index of any model input or output using a tensor name: -- GitLab From 1a25a8e610db416225e4e7373337a0f47dd6e87e Mon Sep 17 00:00:00 2001 From: Jian Li Date: Tue, 4 Sep 2018 12:46:13 -0700 Subject: [PATCH 062/540] Create layer norm LSTM custom Op. PiperOrigin-RevId: 211505721 --- tensorflow/contrib/lite/kernels/BUILD | 15 + .../internal/optimized/neon_tensor_utils.h | 12 + .../internal/optimized/tensor_utils_impl.h | 8 + .../reference/portable_tensor_utils.cc | 36 + .../reference/portable_tensor_utils.h | 22 + .../lite/kernels/internal/tensor_utils.h | 10 + .../kernels/internal/tensor_utils_test.cc | 90 ++ .../contrib/lite/kernels/layer_norm_lstm.cc | 1316 +++++++++++++++++ .../lite/kernels/layer_norm_lstm_test.cc | 664 +++++++++ tensorflow/contrib/lite/kernels/register.cc | 2 + 10 files changed, 2175 insertions(+) create mode 100644 tensorflow/contrib/lite/kernels/layer_norm_lstm.cc create mode 100644 tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 8287115f5c..ca66fa6aa0 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -177,6 +177,7 @@ cc_library( "gather.cc", "hashtable_lookup.cc", "l2norm.cc", + "layer_norm_lstm.cc", "local_response_norm.cc", "logical.cc", "lsh_projection.cc", @@ -903,6 +904,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "layer_norm_lstm_test", + size = "small", + srcs = ["layer_norm_lstm_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + tf_cc_test( name = "lstm_test", size = "small", diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index e671624fe7..5ca1b4b76f 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -79,6 +79,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1, n_batch, result, result_stride); } +void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector); +} + void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector) { PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); @@ -138,6 +143,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector, reduction_size); } +void MeanStddevNormalization(const float* input_vector, float* output_vector, + int v_size, int n_batch, + float normalization_epsilon) { + PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch, + normalization_epsilon); +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index 8664ebc4f6..7e53dc2fa2 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -117,6 +117,10 @@ void PortableClipVector(const float* vector, int v_size, float abs_limit, void NeonClipVector(const float* vector, int v_size, float abs_limit, float* result); +// Add another vector for each batch in the batch vector. +void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector); + // Batch vector initialization with another vector. void PortableVectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector); @@ -172,6 +176,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector, void NeonReductionSumVector(const float* input_vector, float* output_vector, int output_size, int reduction_size); +void PortableMeanStddevNormalization(const float* input_vector, + float* output_vector, int v_size, + int n_batch, float normalization_epsilon); + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index e79e75a898..2a30910c3f 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -173,6 +173,16 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, } } +void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector) { + for (int b = 0; b < n_batch; b++) { + for (int i = 0; i < v_size; ++i) { + batch_vector[i] += vector[i]; + } + batch_vector += v_size; + } +} + void PortableVectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector) { for (int b = 0; b < n_batch; b++) { @@ -243,5 +253,31 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector, } } +void PortableMeanStddevNormalization(const float* input_vector, + float* output_vector, int v_size, + int n_batch, float normalization_epsilon) { + for (int batch = 0; batch < n_batch; ++batch) { + float sum = 0.0f; + float sum_sq = 0.0f; + for (int i = 0; i < v_size; ++i) { + sum += input_vector[i]; + sum_sq += input_vector[i] * input_vector[i]; + } + const float mean = sum / v_size; + float stddev_inv = 0.0f; + const float variance = sum_sq / v_size - mean * mean; + if (variance == 0) { + stddev_inv = 1.0f / sqrt(normalization_epsilon); + } else { + stddev_inv = 1.0f / sqrt(variance); + } + for (int i = 0; i < v_size; ++i) { + output_vector[i] = (input_vector[i] - mean) * stddev_inv; + } + input_vector += v_size; + output_vector += v_size; + } +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index 3829be0c5e..f5b3a84f07 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -87,6 +87,10 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, void PortableVectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector); +// Add another vector for each batch in the batch vector. +void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector); + // Apply sigmoid to elements of a vector. void PortableApplySigmoidToVector(const float* vector, int v_size, float* result); @@ -125,6 +129,12 @@ void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); void PortableReductionSumVector(const float* input_vector, float* output_vector, int output_size, int reduction_size); +// Layer norm for each batch. +// normalization_epsilon is added to avoid divergence. +void PortableMeanStddevNormalization(const float* input_vector, + float* output_vector, int v_size, + int n_batch, float normalization_epsilon); + float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } bool IsZeroVector(const float* vector, int v_size) { @@ -193,6 +203,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1, result, result_stride); } +void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector); +} + void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector) { PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); @@ -240,6 +255,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector, reduction_size); } +void MeanStddevNormalization(const float* input_vector, float* output_vector, + int v_size, int n_batch, + float normalization_epsilon) { + PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch, + normalization_epsilon); +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 748356d1bd..1439bf8c37 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -113,6 +113,10 @@ void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, const float* batch_vector, int n_batch, float* result); +// Add another vector for each batch in the batch vector. +void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch, + float* batch_vector); + // Batch vector initialization with another vector. void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, float* batch_vector); @@ -152,6 +156,12 @@ void VectorShiftLeft(float* vector, int v_size, float shift_value); // added to get one element of output. void ReductionSumVector(const float* input_vector, float* output_vector, int output_size, int reduction_size); + +// Layer norm for each batch. +// normalization_epsilon is added to avoid divergence. +void MeanStddevNormalization(const float* input_vector, float* output_vector, + int v_size, int n_batch, + float normalization_epsilon); } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index 240fb64ca3..dad924fc28 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -496,6 +496,16 @@ TEST(uKernels, VectorVectorCwiseProductAccumulateTest) { {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45}))); } +TEST(uKernels, VectorBatchVectorAddTest) { + constexpr int kVectorSize = 3; + constexpr int kBatchSize = 2; + static float input[kVectorSize] = {0.0, -0.5, 1.0}; + std::vector output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data()); + EXPECT_THAT(output, + testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0})); +} + TEST(uKernels, VectorBatchVectorAssignTest) { constexpr int kVectorSize = 5; constexpr int kBatchSize = 3; @@ -712,5 +722,85 @@ TEST(uKernels, ReductionSumVectorTest) { EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5}))); } +TEST(uKernels, MeanStddevNormalizationNoneZeroInput) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // None-zero input. + static float input[kVectorSize * kBatchSize] = { + 0.1, 0.2, 0.3, 0.4, // batch 0 + 0.9, 1.0, 1.1, 1.2, // batch 1 + }; + std::vector output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector expected_output = { + -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 0 + -1.34163153, -0.447210163, 0.447211236, 1.3416326, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + +TEST(uKernels, MeanStddevNormalizationAllZeroInput) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // Zero input. + static float input[kVectorSize * kBatchSize] = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + 0.0, 0.0, 0.0, 0.0, // batch 1 + }; + std::vector output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector expected_output = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + 0.0, 0.0, 0.0, 0.0, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + +TEST(uKernels, MeanStddevNormalizationMixed) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // Mix of zero and non-zero input. + static float input[kVectorSize * kBatchSize] = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + 0.1, 0.2, 0.3, 0.4, // batch 1 + }; + std::vector output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector expected_output = { + 0.0, 0.0, 0.0, 0.0, // batch 0 + -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + +TEST(uKernels, MeanStddevNormalizationSmallValue) { + constexpr int kVectorSize = 4; + constexpr int kBatchSize = 2; + constexpr float kNormalizationEpsilon = 1e-8; + + // Mix of zero and non-zero input. + static float input[kVectorSize * kBatchSize] = { + 3e-5, -7e-6, -9e-5, 1e-6, // batch 0 + 4e-5, 9e-6, 2e-4, 0.0, // batch 1 + }; + std::vector output(kVectorSize * kBatchSize); + MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize, + kNormalizationEpsilon); + const std::vector expected_output = { + 1.04231524, 0.212946132, -1.64753067, 0.392269224, // batch 0 + -0.275023013, -0.658201098, 1.70267045, -0.769446373, // batch 1 + }; + EXPECT_THAT(output, testing::ElementsAreArray(expected_output)); +} + } // namespace tensor_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc new file mode 100644 index 0000000000..1bbea67b93 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc @@ -0,0 +1,1316 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Layer Normalization LSTM op that applies normalization by mean and standard +// deviation to the activation of the LSTM layers. Please see +// https://arxiv.org/abs/1607.06450 for details. +#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace layer_norm_lstm { + +// Struct to hold Layer Norm LSTM option data. +struct OpData { + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + int scratch_tensor_index; +}; + +// Input Tensors of size {n_batch, n_input} +constexpr int kInputTensor = 0; + +// Input weight tensors of size: {n_cell, n_input} +constexpr int kInputToInputWeightsTensor = 1; // Optional +constexpr int kInputToForgetWeightsTensor = 2; +constexpr int kInputToCellWeightsTensor = 3; +constexpr int kInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kRecurrentToForgetWeightsTensor = 6; +constexpr int kRecurrentToCellWeightsTensor = 7; +constexpr int kRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kCellToInputWeightsTensor = 9; // Optional +constexpr int kCellToForgetWeightsTensor = 10; // Optional +constexpr int kCellToOutputWeightsTensor = 11; // Optional + +// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kInputLayerNormWeightsTensor = 12; +constexpr int kForgetLayerNormWeightsTensor = 13; +constexpr int kCellLayerNormWeightsTensor = 14; +constexpr int kOutputLayerNormWeightsTensor = 15; + +// Gates bias tensors of size {n_cell} +constexpr int kInputGateBiasTensor = 16; // Optional +constexpr int kForgetGateBiasTensor = 17; +constexpr int kCellGateBiasTensor = 18; +constexpr int kOutputGateBiasTensor = 19; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kProjectionWeightsTensor = 20; // Optional +// Projection bias tensor of size {n_output} +constexpr int kProjectionBiasTensor = 21; // Optional + +// State tensors. +constexpr int kInputActivationStateTensor = 22; +constexpr int kInputCellStateTensor = 23; + +// Output tensor. +constexpr int kOutputTensor = 0; + +// Total number of scratch tensors for hybrid Op. +constexpr int kTensorsToAdd = 7; + +// Small float to avoid divergence during calculation of deviation. +const float kLayerNormEpsilon = 1e-8; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + + // Turn custom option data into flexbuffer map format. + const uint8_t* buffer_t = reinterpret_cast(buffer); + const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); + + // Get activation function, cell_clip and proj_clip from the flexbuffer. + // TODO(b/113824099): make activation more generic. + assert(m["fused_activation_function"].ToString() == "TANH"); + data->activation = kTfLiteActTanh; + data->cell_clip = m["cell_clip"].AsFloat(); + data->proj_clip = m["proj_clip"].AsFloat(); + + // Populate scratch_tensor_index. + context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd, + &data->scratch_tensor_index); + return data; +} + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell) { + const OpData* op_data = reinterpret_cast(node->user_data); + + // Making sure clipping parameters have valid values. + // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, op_data->cell_clip >= 0); + TF_LITE_ENSURE(context, op_data->proj_clip >= 0); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + if (input_to_input_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + if (recurrent_to_input_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). + const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + if (cell_to_input_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + } + + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + if (cell_to_forget_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + } + + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + if (cell_to_output_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + } + + // Making sure the peephole weights are there all or none. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Making sure layer norm weights are not null and have the right dimension. + const TfLiteTensor* input_layer_norm_weights = + GetInput(context, node, kInputLayerNormWeightsTensor); + TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell); + + const TfLiteTensor* forget_layer_norm_weights = + GetInput(context, node, kForgetLayerNormWeightsTensor); + TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell); + + const TfLiteTensor* cell_layer_norm_weights = + GetInput(context, node, kCellLayerNormWeightsTensor); + TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell); + + const TfLiteTensor* output_layer_norm_weights = + GetInput(context, node, kOutputLayerNormWeightsTensor); + TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr); + TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + } + + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); + + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights != nullptr) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + if (projection_bias != nullptr) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + const bool projection_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projection_tensors_consistent == true); + + return kTfLiteOk; +} + +// Resize the output, state tensors based on the sizes of the input tensors. +// Allocate a temporary scratch tensor. Also check that the sizes of the input +// tensors match each other. +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* op_data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 24); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + // Inferring batch size, number of outputs and number of cells from the + // input tensors. + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE(context, input->dims->size > 1); + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + const int n_cell = input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); + + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], + n_cell); + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input, + n_output, n_cell)); + + // Get the pointer to output, activation_state and cell_state tensors. + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const TfLiteTensor* activation_state = + GetInput(context, node, kInputActivationStateTensor); + const TfLiteTensor* cell_state = + GetInput(context, node, kInputCellStateTensor); + + // Check the shape of input state tensors. + // These tensor may be 1D or 2D. It's fine as long as the total size is + // correct. + TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output); + TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); + // Resize the output tensors. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); + output_size->data[0] = n_batch; + output_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + + // The weights are of consistent type, so it suffices to check one. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + + TfLiteIntArrayFree(node->temporaries); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(7); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = op_data->scratch_tensor_index; + + // Create a scratch buffer tensor. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + scratch_buffer->type = input->type; + scratch_buffer->allocation_type = kTfLiteArenaRw; + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + if (use_cifg) { + // Reserving space for Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 3; + } else { + // Reserving space for Input, Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 4; + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // activation_state and cell_state tensors. + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; + TfLiteTensor* activation_state_quantized = + GetTemporary(context, node, /*index=*/2); + activation_state_quantized->type = kTfLiteUInt8; + activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(activation_state_quantized->dims, + activation_state->dims)) { + TfLiteIntArray* activation_state_quantized_size = + TfLiteIntArrayCopy(activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, activation_state_quantized, + activation_state_quantized_size)); + } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[6] = op_data->scratch_tensor_index + 6; + TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6); + recovered_weights->type = kTfLiteFloat32; + recovered_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1); + recovered_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_weights->dims, recovered_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_weights, + recovered_weights_size)); + } + } + return kTfLiteOk; +} + +void LayerNormLstmStep( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, + const float* input_layer_norm_weight_ptr, + const float* forget_layer_norm_weight_ptr, + const float* cell_layer_norm_weight_ptr, + const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, float cell_clip, float proj_clip, + const TfLiteFusedActivation& activation, int n_batch, int n_cell, + int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr, + float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, + float* output_gate_scratch, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + + // Initialize scratch buffers with 0. + if (!use_cifg) { + tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); + } + tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, forget_gate_scratch, + /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, output_gate_scratch, + /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::MeanStddevNormalization(input_gate_scratch, + input_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr, + n_cell, input_gate_scratch, + n_batch, input_gate_scratch); + tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::MeanStddevNormalization(forget_gate_scratch, + forget_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr, + n_cell, forget_gate_scratch, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, + n_batch, kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch); + tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip, + cell_state_ptr); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::MeanStddevNormalization(output_gate_scratch, + output_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr, + n_cell, output_gate_scratch, + n_batch, output_gate_scratch); + tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, + output_ptr_batch, /*result_stride=*/1); + if (proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip, + output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + +void LayerNormLstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, + const float* input_layer_norm_weight_ptr, + const float* forget_layer_norm_weight_ptr, + const float* cell_layer_norm_weight_ptr, + const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + float cell_clip, float proj_clip, const TfLiteFusedActivation& activation, + int n_batch, int n_cell, int n_input, int n_output, + float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, + float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + + // Initialize scratch buffers with 0. + if (!use_cifg) { + tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch); + } + tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch); + tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell, + cell_to_input_weights_scale, + recovered_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::MeanStddevNormalization(input_gate_scratch, + input_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr, + n_cell, input_gate_scratch, + n_batch, input_gate_scratch); + tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell, + cell_to_forget_weights_scale, + recovered_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::MeanStddevNormalization(forget_gate_scratch, + forget_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr, + n_cell, forget_gate_scratch, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, + n_batch, kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct( + cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch); + tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip, + cell_state_ptr); + } + + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell, + cell_to_output_weights_scale, + recovered_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::MeanStddevNormalization(output_gate_scratch, + output_gate_scratch, n_cell, n_batch, + kLayerNormEpsilon); + tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr, + n_cell, output_gate_scratch, + n_batch, output_gate_scratch); + tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, + product_scaling_factors, n_batch, output_ptr_batch, + /*result_stride=*/1); + } + if (proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip, + output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + +// The LayerNormLSTM Op engine. +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_layer_norm_weights, + const TfLiteTensor* forget_layer_norm_weights, + const TfLiteTensor* cell_layer_norm_weights, + const TfLiteTensor* output_layer_norm_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + float cell_clip, float proj_clip, const TfLiteFusedActivation& activation, + TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + const float* input_to_input_weights_ptr = + (use_cifg) ? nullptr : input_to_input_weights->data.f; + const float* recurrent_to_input_weights_ptr = + (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; + const float* input_gate_bias_ptr = + (use_cifg) ? nullptr : input_gate_bias->data.f; + const float* cell_to_input_weights_ptr = + (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; + const float* cell_to_forget_weights_ptr = + (use_peephole) ? cell_to_forget_weights->data.f : nullptr; + const float* cell_to_output_weights_ptr = + (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* projection_weights_ptr = + (projection_weights == nullptr) ? nullptr : projection_weights->data.f; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f; + const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f; + const float* input_to_output_weights_ptr = input_to_output_weights->data.f; + const float* recurrent_to_forget_weights_ptr = + recurrent_to_forget_weights->data.f; + const float* recurrent_to_cell_weights_ptr = + recurrent_to_cell_weights->data.f; + const float* recurrent_to_output_weights_ptr = + recurrent_to_output_weights->data.f; + const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f; + const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f; + const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f; + const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* activation_state_ptr = activation_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + LayerNormLstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr, + input_to_cell_weights_ptr, input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr, + recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr, + cell_to_input_weights_ptr, cell_to_forget_weights_ptr, + cell_to_output_weights_ptr, input_layer_norm_weight_ptr, + forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr, + output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, + cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, + projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell, + n_input, n_output, activation_state_ptr, cell_state_ptr, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, output_ptr_batch); + + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_layer_norm_weights, + const TfLiteTensor* forget_layer_norm_weights, + const TfLiteTensor* cell_layer_norm_weights, + const TfLiteTensor* output_layer_norm_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + float cell_clip, float proj_clip, const TfLiteFusedActivation& activation, + TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, + TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights, + TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized, + TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f; + const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f; + const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f; + const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* activation_state_ptr = activation_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_activation_state_ptr = + reinterpret_cast(activation_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_weights_ptr = recovered_weights->data.f; + + LayerNormLstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr, + cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, + projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell, + n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_weights_ptr, quantized_input_ptr, + quantized_activation_state_ptr, quantized_cell_state_ptr, + activation_state_ptr, cell_state_ptr, output_ptr_batch); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const OpData* op_data = reinterpret_cast(node->user_data); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_layer_norm_weights = + GetInput(context, node, kInputLayerNormWeightsTensor); + const TfLiteTensor* forget_layer_norm_weights = + GetInput(context, node, kForgetLayerNormWeightsTensor); + const TfLiteTensor* cell_layer_norm_weights = + GetInput(context, node, kCellLayerNormWeightsTensor); + const TfLiteTensor* output_layer_norm_weights = + GetInput(context, node, kOutputLayerNormWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* activation_state = + &context->tensors[node->inputs->data[kInputActivationStateTensor]]; + TfLiteTensor* cell_state = + &context->tensors[node->inputs->data[kInputCellStateTensor]]; + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_layer_norm_weights, + forget_layer_norm_weights, cell_layer_norm_weights, + output_layer_norm_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, op_data->cell_clip, + op_data->proj_clip, op_data->activation, scratch_buffer, + activation_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* activation_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_layer_norm_weights, forget_layer_norm_weights, + cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, projection_weights, + projection_bias, op_data->cell_clip, op_data->proj_clip, + op_data->activation, scratch_buffer, scaling_factors, + prod_scaling_factors, recovered_weights, input_quantized, + activation_state_quantized, cell_state_quantized, activation_state, + cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +} // namespace layer_norm_lstm + +TfLiteRegistration* Register_LAYER_NORM_LSTM() { + static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free, + layer_norm_lstm::Prepare, + layer_norm_lstm::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc new file mode 100644 index 0000000000..abc229f85a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc @@ -0,0 +1,664 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Layer Norm LSTM op. + +#include +#include + +#include +#include +#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_LAYER_NORM_LSTM(); + +namespace { + +using ::testing::ElementsAreArray; + +class LayerNormLSTMOpModel : public SingleOpModel { + public: + LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(weight_type); + } + + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(weight_type); + } + + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(weight_type); + } + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + input_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + output_layer_norm_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(weight_type); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + // Adding the 2 state tensors. + output_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true); + cell_state_ = + AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true); + + output_ = AddOutput(TensorType_FLOAT32); + + // Set up and pass in custom options using flexbuffer. + flexbuffers::Builder fbb; + fbb.Map([&]() { + fbb.Int("cell_clip", cell_clip); + fbb.Int("proj_clip", proj_clip); + fbb.String("fused_activation_function", "TANH"); + }); + fbb.Finish(); + SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputLayerNormWeights(std::initializer_list f) { + PopulateTensor(input_layer_norm_weights_, f); + } + + void SetForgetLayerNormWeights(std::initializer_list f) { + PopulateTensor(forget_layer_norm_weights_, f); + } + + void SetCellLayerNormWeights(std::initializer_list f) { + PopulateTensor(cell_layer_norm_weights_, f); + } + + void SetOutputLayerNormWeights(std::initializer_list f) { + PopulateTensor(output_layer_norm_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + protected: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_layer_norm_weights_; + int forget_layer_norm_weights_; + int cell_layer_norm_weights_; + int output_layer_norm_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_state_; + int cell_state_; + + int output_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel { + public: + HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, + bool use_projection_bias, float cell_clip, + float proj_clip, + const std::vector>& input_shapes) + : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, + use_peephole, use_projection_weights, + use_projection_bias, cell_clip, proj_clip, + input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetInputLayerNormWeights(std::initializer_list f) { + PopulateTensor(input_layer_norm_weights_, f); + } + + void SetForgetLayerNormWeights(std::initializer_list f) { + PopulateTensor(forget_layer_norm_weights_, f); + } + + void SetCellLayerNormWeights(std::initializer_list f) { + PopulateTensor(cell_layer_norm_weights_, f); + } + + void SetOutputLayerNormWeights(std::initializer_list f) { + PopulateTensor(output_layer_norm_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLayerNormLstmTest : public ::testing::Test { + protected: + // Weights of the Layer Norm LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list input_layer_norm_weights_; + std::initializer_list forget_layer_norm_weights_; + std::initializer_list cell_layer_norm_weights_; + std::initializer_list output_layer_norm_weights_; + std::initializer_list projection_weights_; + + // Layer Norm LSTM input is stored as num_batch x num_inputs vector. + std::vector> layer_norm_lstm_input_; + + // Compares output up to tolerance to the result of the layer_norm_lstm given + // the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LayerNormLSTMOpModel* layer_norm_lstm, + float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = layer_norm_lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(), + batch_start, batch_end); + } + + layer_norm_lstm->Invoke(); + + const int num_outputs = layer_norm_lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(layer_norm_lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + } + } +}; + +class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest + : public BaseLayerNormLstmTest { + void SetUp() override { + input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, + 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5, + -0.4, -0.5, -0.4, -0.3, -0.2, -0.1}; + + input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, + -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4, + -0.6, 0.3, -0.4, -0.6, -0.5, -0.5}; + + input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, + -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3, + -0.7, 0.7, -0.9, -0.5, 0.8, 0.6}; + + input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, + -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7, + -0.3, -0.5, 0.1, 0.5, -0.6, -0.4}; + + input_gate_bias_ = {0.03, 0.15, 0.22, 0.38}; + + forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1}; + + cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08}; + + output_gate_bias_ = {0.05, -0.01, 0.2, 0.1}; + + recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, + -0.2, -0.3, -0.7, 0.05, -0.2, -0.6}; + + recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, + -0.2, 0.3, 0.8, -0.6, -0.1, 0.2}; + + recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, + 0.9, 0.3, -0.1, 0.2, 0.5, 0.2}; + + recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, + -0.2, -0.6, -0.1, -0.4, -0.7, -0.2}; + + cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15}; + + cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03}; + + cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05}; + + input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5}; + forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3}; + cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8}; + output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5}; + + projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, + 0.3, 0.08, 0.07, 0.2, -0.4, 0.2}; + + layer_norm_lstm_input_ = { + {// Batch0: 3 (input_sequence_size) * 5 (n_input) + 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0 + 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1 + 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2 + + {// Batch1: 3 (input_sequence_size) * 5 (n_input) + 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0 + 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1 + 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2 + }; + } +}; + +TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, + LayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + LayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_layer_norm_weight tensor + {n_cell}, // forget_layer_norm_weight tensor + {n_cell}, // cell_layer_norm_weight tensor + {n_cell}, // output_layer_norm_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetInputGateBias(input_gate_bias_); + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_); + layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); + layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); + layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + // Verify the final output. + const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0244077, 0.128027, -0.00170918, // seq 0 + 0.0137642, 0.140751, 0.0395835, // seq 1 + -0.00459231, 0.155278, 0.0837377, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.00692428, 0.0848741, 0.063445, // seq 0 + -0.00403912, 0.139963, 0.072681, // seq 1 + 0.00752706, 0.161903, 0.0561371, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + +TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest, + HybridLayerNormLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 4; + const int n_output = 3; + const float ceil_clip = 0.0; + const float proj_clip = 0.0; + + HybridLayerNormLSTMOpModel layer_norm_lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, ceil_clip, proj_clip, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_layer_norm_weight tensor + {n_cell}, // forget_layer_norm_weight tensor + {n_cell}, // cell_layer_norm_weight tensor + {n_cell}, // output_layer_norm_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_); + layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_); + layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_); + layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_); + + layer_norm_lstm.SetInputGateBias(input_gate_bias_); + layer_norm_lstm.SetCellBias(cell_gate_bias_); + layer_norm_lstm.SetForgetGateBias(forget_gate_bias_); + layer_norm_lstm.SetOutputGateBias(output_gate_bias_); + + layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_); + layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_); + layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_); + + layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_); + layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_); + layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_); + layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_); + + layer_norm_lstm.SetProjectionWeights(projection_weights_); + + const std::vector> layer_norm_lstm_golden_output = { + { + // Batch0: 3 (input_sequence_size) * 3 (n_output) + 0.0244576, 0.127847, -0.00181765, // seq 0 + 0.0137518, 0.140892, 0.0402234, // seq 1 + -0.0048839, 0.155096, 0.0840309, // seq 2 + }, + { + // Batch1: 3 (input_sequence_size) * 3 (n_output) + -0.00728636, 0.0843957, 0.0634786, // seq 0 + -0.00448382, 0.139278, 0.0737372, // seq 1 + 0.00734616, 0.161793, 0.0560238, // seq 2 + }}; + + VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output, + &layer_norm_lstm); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 7b859dc332..188015f43c 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -22,6 +22,7 @@ namespace ops { namespace custom { TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); +TfLiteRegistration* Register_LAYER_NORM_LSTM(); TfLiteRegistration* Register_MFCC(); TfLiteRegistration* Register_DETECTION_POSTPROCESS(); @@ -247,6 +248,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddCustom("Mfcc", tflite::ops::custom::Register_MFCC()); AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); + AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM()); AddCustom("TFLite_Detection_PostProcess", tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } -- GitLab From c0a9c988f75082b0c8521c1343874c4ce9a10dd6 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 4 Sep 2018 13:09:17 -0700 Subject: [PATCH 063/540] Fix bugs in backward convolution benchmarks + add labels. PiperOrigin-RevId: 211510051 --- tensorflow/core/kernels/eigen_benchmark.h | 74 ++++++++++--------- .../core/kernels/eigen_benchmark_cpu_test.cc | 15 +++- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h index 46ad38fb77..87e41b89b3 100644 --- a/tensorflow/core/kernels/eigen_benchmark.h +++ b/tensorflow/core/kernels/eigen_benchmark.h @@ -76,6 +76,9 @@ class SpatialConvolutionBenchmarksSuite { void SpatialConvolutionBackwardInput(Dimensions input_dims, Dimensions filter_dims) { + using OutputBackward = TTypes::ConstTensor; + using InputBackward = TTypes::Tensor; + Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height input_dims[2], // input_width @@ -85,37 +88,37 @@ class SpatialConvolutionBenchmarksSuite { Eigen::Index input_rows = input_dims[1]; Eigen::Index input_cols = input_dims[2]; - Scalar* input_data = - static_cast(device_.allocate(BufferSize(input_dims))); Scalar* filter_data = static_cast(device_.allocate(BufferSize(filter_dims))); - Scalar* output_data = + Scalar* output_backward_data = static_cast(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast(device_.allocate(BufferSize(input_dims))); - device_.memset(input_data, 123, BufferSize(input_dims)); device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); - Input input(input_data, input_dims); Filter filter(filter_data, filter_dims); - Output output(output_data, output_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - output.device(device_) = Eigen::SpatialConvolutionBackwardInput( - filter, input, input_rows, input_cols); - tensorflow::testing::DoNotOptimize(output); + input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput( + filter, output_backward, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); } ::tensorflow::testing::StopTiming(); - device_.deallocate(input_data); device_.deallocate(filter_data); - device_.deallocate(output_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); } void SpatialConvolutionBackwardKernel(Dimensions input_dims, Dimensions filter_dims) { using OutputBackward = TTypes::ConstTensor; - using FilterGrad = TTypes::Tensor; + using FilterBackward = TTypes::Tensor; Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height @@ -130,7 +133,7 @@ class SpatialConvolutionBenchmarksSuite { static_cast(device_.allocate(BufferSize(input_dims))); Scalar* output_backward_data = static_cast(device_.allocate(BufferSize(output_dims))); - Scalar* filter_data = + Scalar* filter_backward_data = static_cast(device_.allocate(BufferSize(filter_dims))); device_.memset(input_data, 123, BufferSize(input_dims)); @@ -138,19 +141,19 @@ class SpatialConvolutionBenchmarksSuite { Input input(input_data, input_dims); OutputBackward output_backward(output_backward_data, input_dims); - FilterGrad filter_grad(filter_data, filter_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - filter_grad.device(device_) = Eigen::SpatialConvolutionBackwardKernel( + filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel( input, output_backward, filter_rows, filter_cols); - tensorflow::testing::DoNotOptimize(filter_grad); + tensorflow::testing::DoNotOptimize(filter_backward); } ::tensorflow::testing::StopTiming(); device_.deallocate(input_data); device_.deallocate(output_backward_data); - device_.deallocate(filter_data); + device_.deallocate(filter_backward_data); } private: @@ -215,42 +218,45 @@ class CuboidConvolutionBenchmarksSuite { input_dims[3], // input_planes filter_dims[4]); // filter_count + using OutputBackward = TTypes::ConstTensor; + using InputBackward = TTypes::Tensor; + // Assuming that the convolution had SAME padding. Eigen::Index input_rows = input_dims[1]; Eigen::Index input_cols = input_dims[2]; Eigen::Index input_planes = input_dims[3]; - Scalar* input_data = - static_cast(device_.allocate(BufferSize(input_dims))); Scalar* filter_data = static_cast(device_.allocate(BufferSize(filter_dims))); - Scalar* output_data = + Scalar* output_backward_data = static_cast(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast(device_.allocate(BufferSize(input_dims))); - device_.memset(input_data, 123, BufferSize(input_dims)); device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); - Input input(input_data, input_dims); Filter filter(filter_data, filter_dims); - Output output(output_data, output_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - output.device(device_) = Eigen::CuboidConvolutionBackwardInput( - filter, input, input_planes, input_rows, input_cols); - tensorflow::testing::DoNotOptimize(output); + input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, input_planes, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); } ::tensorflow::testing::StopTiming(); - device_.deallocate(input_data); device_.deallocate(filter_data); - device_.deallocate(output_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); } void CuboidConvolutionBackwardKernel(Dimensions input_dims, Dimensions filter_dims) { using OutputBackward = TTypes::ConstTensor; - using FilterGrad = TTypes::Tensor; + using FilterBackward = TTypes::Tensor; Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height @@ -267,7 +273,7 @@ class CuboidConvolutionBenchmarksSuite { static_cast(device_.allocate(BufferSize(input_dims))); Scalar* output_backward_data = static_cast(device_.allocate(BufferSize(output_dims))); - Scalar* filter_data = + Scalar* filter_backward_data = static_cast(device_.allocate(BufferSize(filter_dims))); device_.memset(input_data, 123, BufferSize(input_dims)); @@ -275,19 +281,19 @@ class CuboidConvolutionBenchmarksSuite { Input input(input_data, input_dims); OutputBackward output_backward(output_backward_data, output_dims); - FilterGrad filter_grad(filter_data, filter_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - filter_grad.device(device_) = Eigen::CuboidConvolutionBackwardKernel( + filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel( input, output_backward, filter_planes, filter_rows, filter_cols); - tensorflow::testing::DoNotOptimize(filter_grad); + tensorflow::testing::DoNotOptimize(filter_backward); } ::tensorflow::testing::StopTiming(); device_.deallocate(input_data); device_.deallocate(output_backward_data); - device_.deallocate(filter_data); + device_.deallocate(filter_backward_data); } private: diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc index 2a8308ef9a..7c2bbb8148 100644 --- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc +++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc @@ -123,6 +123,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \ FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW)) @@ -130,6 +131,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \ FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK( \ @@ -138,6 +140,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \ FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \ @@ -348,6 +351,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, #define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \ FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK( \ @@ -356,6 +360,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, #define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \ FH, FW, FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \ @@ -365,6 +370,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \ FC, FH, FW, FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \ @@ -395,8 +401,11 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, BM_CuboidConvolutions(8, // batch size 25, 25, 25, 4, // input: height, width, panes, depth 16, 5, 5, 5, // filter: count, height, width, panes - "conv3d"); + "conv3d_depth4"); +BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); -BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d"); +BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); +BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); -BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d"); +BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); +BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); -- GitLab From bfdb7a408c1ea519df9f970220e36c89e8fe1cf3 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Tue, 4 Sep 2018 13:29:25 -0700 Subject: [PATCH 064/540] Disable rtti for builtin TFLite kernels PiperOrigin-RevId: 211514002 --- tensorflow/contrib/lite/kernels/BUILD | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index ca66fa6aa0..ab989c5425 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -6,7 +6,7 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android") # Suppress warnings that are introduced by Eigen Tensor. EXTRA_EIGEN_COPTS = select({ @@ -147,7 +147,7 @@ tf_cc_test( ) cc_library( - name = "builtin_ops", + name = "builtin_op_kernels", srcs = [ "activations.cc", "add.cc", @@ -192,7 +192,6 @@ cc_library( "pooling.cc", "pow.cc", "reduce.cc", - "register.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", @@ -217,9 +216,9 @@ cc_library( ], hdrs = [ "padding.h", - "register.h", ], - copts = tflite_copts() + EXTRA_EIGEN_COPTS, + copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS, + visibility = ["//visibility:private"], deps = [ ":activation_functor", ":eigen_support", @@ -243,6 +242,17 @@ cc_library( ], ) +cc_library( + name = "builtin_ops", + srcs = ["register.cc"], + hdrs = ["register.h"], + deps = [ + ":builtin_op_kernels", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:util", + ], +) + tf_cc_test( name = "audio_spectrogram_test", size = "small", -- GitLab From 3db96c74a414f1a5be2bc84b2c263ce84f1b998a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 13:30:49 -0700 Subject: [PATCH 065/540] Removed old dynamic learning rate support code. PiperOrigin-RevId: 211514287 --- tensorflow/contrib/tpu/proto/optimization_parameters.proto | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index bf807af68b..cbf6809257 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -18,8 +18,10 @@ message DynamicLearningRate { message LearningRate { oneof learning_rate { float constant = 1; - DynamicLearningRate dynamic = 2; + // DynamicLearningRate dynamic = 2; -- disabled while code is being + // rewritten. } + reserved 2; } message AdagradParameters { -- GitLab From 0cd9b3e41d993f505feb54ff0b086ffbb21b595d Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Tue, 4 Sep 2018 13:39:03 -0700 Subject: [PATCH 066/540] Support <4D tensor inputs for pad/padv2 Fixes #21266 PiperOrigin-RevId: 211515918 --- tensorflow/contrib/lite/kernels/pad.cc | 34 +++++++++---------- tensorflow/contrib/lite/kernels/pad_test.cc | 13 +++++-- .../contrib/lite/testing/generate_examples.py | 25 +++++++++++--- .../testing/generated_examples_zip_test.cc | 6 ---- 4 files changed, 49 insertions(+), 29 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc index 55bcf3b533..3bce05353d 100644 --- a/tensorflow/contrib/lite/kernels/pad.cc +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -92,8 +92,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { op_context.constant_values->type); } - // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. - TF_LITE_ENSURE_EQ(context, op_context.dims, 4); + // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D. + TF_LITE_ENSURE(context, op_context.dims <= 4); // Exit early if paddings is a non-const tensor. Set output tensor to // dynamic so output size can be determined in Eval. @@ -134,21 +134,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { after_padding.push_back(paddings_data[idx * 2 + 1]); } -#define TF_LITE_PAD(type, scalar, pad_value) \ - TF_LITE_ENSURE_EQ(context, before_padding.size(), 4); \ - TF_LITE_ENSURE_EQ(context, after_padding.size(), 4); \ - tflite::PadParams op_params; \ - op_params.left_padding_count = 4; \ - op_params.right_padding_count = 4; \ - for (int i = 0; i < 4; ++i) { \ - op_params.left_padding[i] = before_padding[3 - i]; \ - op_params.right_padding[i] = after_padding[3 - i]; \ - } \ - const scalar pad_value_copy = pad_value; \ - \ - type::Pad(op_params, GetTensorShape(op_context.input), \ - GetTensorData(op_context.input), &pad_value_copy, \ - GetTensorShape(op_context.output), \ +#define TF_LITE_PAD(type, scalar, pad_value) \ + TF_LITE_ENSURE(context, before_padding.size() <= 4); \ + TF_LITE_ENSURE(context, after_padding.size() <= 4); \ + tflite::PadParams op_params; \ + op_params.left_padding_count = before_padding.size(); \ + op_params.right_padding_count = after_padding.size(); \ + for (int i = 0; i < op_context.dims; ++i) { \ + op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \ + op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \ + } \ + const scalar pad_value_copy = pad_value; \ + \ + type::Pad(op_params, GetTensorShape(op_context.input), \ + GetTensorData(op_context.input), &pad_value_copy, \ + GetTensorShape(op_context.output), \ GetTensorData(op_context.output)) switch (op_context.input->type) { case kTfLiteFloat32: { diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc index f8b9064fbb..f663899713 100644 --- a/tensorflow/contrib/lite/kernels/pad_test.cc +++ b/tensorflow/contrib/lite/kernels/pad_test.cc @@ -193,7 +193,7 @@ TEST(PadOpTest, TooManyDimensions) { PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, {TensorType_FLOAT32}), - "dims != 4"); + "dims <= 4"); } TEST(PadOpTest, UnequalDimensions) { @@ -221,6 +221,15 @@ TEST(PadOpTest, SimpleConstTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } +TEST(PadOpTest, SimpleConst1DTest) { + PadOpConstModel m({TensorType_FLOAT32, {2}}, {1, 2}, {1, 2}, + {TensorType_FLOAT32}); + m.SetInput({2, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 3, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5})); +} + TEST(PadOpTest, SimpleDynamicTest) { PadOpDynamicModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2}, {TensorType_FLOAT32}); @@ -334,7 +343,7 @@ TEST(PadV2OpTest, TooManyDimensions) { {TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2}, {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0, {TensorType_FLOAT32}), - "dims != 4"); + "dims <= 4"); } TEST(PadV2OpTest, UnequalDimensions) { diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 57134ccd15..32f02a4f6c 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -1679,6 +1679,7 @@ def make_pad_tests(zip_path): # TODO(nupurgarg): Add test for tf.uint8. test_parameters = [ + # 4D: { "dtype": [tf.int32, tf.int64, tf.float32], "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]], @@ -1686,13 +1687,20 @@ def make_pad_tests(zip_path): [0, 0], [2, 3]]], "constant_paddings": [True, False], }, - # Non-4D use case. + # 2D: { "dtype": [tf.int32, tf.int64, tf.float32], - "input_shape": [[1, 2], [0, 1, 2]], + "input_shape": [[1, 2]], "paddings": [[[0, 1], [2, 3]]], "constant_paddings": [True, False], }, + # 1D: + { + "dtype": [tf.int32], + "input_shape": [[1]], + "paddings": [[[1, 2]]], + "constant_paddings": [False], + }, ] def build_graph(parameters): @@ -1730,6 +1738,7 @@ def make_padv2_tests(zip_path): # TODO(nupurgarg): Add test for tf.uint8. test_parameters = [ + # 4D: { "dtype": [tf.int32, tf.int64, tf.float32], "input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]], @@ -1738,14 +1747,22 @@ def make_padv2_tests(zip_path): "constant_paddings": [True, False], "constant_values": [0, 2], }, - # Non-4D use case. + # 2D: { "dtype": [tf.int32, tf.int64, tf.float32], - "input_shape": [[1, 2], [0, 1, 2]], + "input_shape": [[1, 2]], "paddings": [[[0, 1], [2, 3]]], "constant_paddings": [True, False], "constant_values": [0, 2], }, + # 1D: + { + "dtype": [tf.int32], + "input_shape": [[1]], + "paddings": [[[0, 1]]], + "constant_paddings": [False], + "constant_values": [0, 2], + }, ] def build_graph(parameters): diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 37c7ae0e1c..349aa5a3b4 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -58,12 +58,6 @@ tensorflow::Env* env = tensorflow::Env::Default(); // Key is a substring of the test name and value is a bug number. // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { - // Pad and PadV2 only supports 4D tensors. - {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", - "70527055"}, - {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])", - "70527055"}, - // L2Norm only supports tensors with 4D or fewer. {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, -- GitLab From ffd9519c3fffe43473f06a1c8fdd12519490db3b Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 4 Sep 2018 13:52:01 -0700 Subject: [PATCH 067/540] Optimize CuboidConvolutionBackwardKernel (Conv3D kernel backprop). * simplify contraction by collapsing inner dims into single dimension * get rid of expensive reverse op ~5X improvement when compiled with AVX. PiperOrigin-RevId: 211518363 --- .../eigen_backward_cuboid_convolutions.h | 304 ++++++------------ 1 file changed, 96 insertions(+), 208 deletions(-) diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h index e13e548f86..3ebeb7be2b 100644 --- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -323,47 +323,34 @@ CuboidConvolutionBackwardInput( template EIGEN_ALWAYS_INLINE static const typename internal::conditional< internal::traits::Layout == ColMajor, - const TensorShufflingOp< - const array::Index, 5>, - const TensorReverseOp< - const array, + TensorReshapingOp< + const DSizes::Index, 5>, + const TensorContractionOp< + const array::Index>, 1>, const TensorReshapingOp< - const DSizes::Index, - 5>, - const TensorContractionOp< - const array< - IndexPair::Index>, 2>, - const TensorReshapingOp< - const DSizes::Index, - 3>, - const Input>, - const TensorReshapingOp< - const DSizes< - typename internal::traits::Index, - 4>, - const TensorVolumePatchOp< - Dynamic, Dynamic, Dynamic, - const OutputBackward> > > > > >, - const TensorShufflingOp< - const array::Index, 5>, - const TensorReverseOp< - const array, + const DSizes::Index, 2>, + const OutputBackward>, + const TensorShufflingOp< + const array::Index, + 2>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorVolumePatchOp > > > >, + TensorReshapingOp< + const DSizes::Index, 5>, + const TensorContractionOp< + const array::Index>, 1>, + const TensorShufflingOp< + const array::Index, + 2>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorVolumePatchOp > >, const TensorReshapingOp< - const DSizes::Index, - 5>, - const TensorContractionOp< - const array< - IndexPair::Index>, 2>, - const TensorReshapingOp< - const DSizes< - typename internal::traits::Index, - 4>, - const TensorVolumePatchOp >, - const TensorReshapingOp< - const DSizes::Index, - 3>, - const Input> > > > > >::type + const DSizes::Index, 2>, + const OutputBackward> > > >::type CuboidConvolutionBackwardKernel( const Input& input, const OutputBackward& output_backward, typename internal::traits::Index kernelPlanes, @@ -406,213 +393,114 @@ CuboidConvolutionBackwardKernel( const TensorIndex outputCols = isColMajor ? out.dimension(3) : out.dimension(NumDims - 4); + // Number of filters. This is the same as the output depth. const TensorIndex kernelFilters = isColMajor ? out.dimension(0) : out.dimension(NumDims - 1); + // Number of channels. This is the same as the input depth. const TensorIndex kernelChannels = isColMajor ? in.dimension(0) : in.dimension(NumDims - 1); - TensorIndex forward_pad_z, forward_pad_y, forward_pad_x; - const TensorIndex size_z = - Eigen::divup(inputPlanes, static_cast(stridePlanes)); - const TensorIndex size_y = - Eigen::divup(inputRows, static_cast(strideRows)); - const TensorIndex size_x = - Eigen::divup(inputCols, static_cast(strideCols)); - - // Infer padding type. - if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) { - // SAME padding. - const TensorIndex dz = numext::maxi( - 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes); - const TensorIndex dy = numext::maxi( - 0, (size_y - 1) * strideRows + kernelRows - inputRows); - const TensorIndex dx = numext::maxi( - 0, (size_x - 1) * strideCols + kernelCols - inputCols); - - forward_pad_z = dz / 2; - forward_pad_y = dy / 2; - forward_pad_x = dx / 2; - } else { - // VALID padding. - forward_pad_z = 0; - forward_pad_y = 0; - forward_pad_x = 0; - } - - const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z; - const TensorIndex padding_top = kernelRows - 1 - forward_pad_y; - const TensorIndex padding_left = kernelCols - 1 - forward_pad_x; - - const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - - (outputPlanes - 1) * stridePlanes - 1 - - padding_ztop; - const TensorIndex padding_bottom = inputRows + kernelRows - 1 - - (outputRows - 1) * strideRows - 1 - - padding_top; - const TensorIndex padding_right = inputCols + kernelCols - 1 - - (outputCols - 1) * strideCols - 1 - - padding_left; - - eigen_assert(padding_ztop >= 0); - eigen_assert(padding_zbottom >= 0); - eigen_assert(padding_top >= 0); - eigen_assert(padding_left >= 0); - eigen_assert(padding_bottom >= 0); - eigen_assert(padding_right >= 0); - - // The output_backward has dimensions out_depth X out_plaens X out_rows X - // out_cols X OTHERS - // When we extract the image patches from output_backward (with input as the - // kernel), it will have dimensions - // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes * - // kernel_rows * kernel_cols) X OTHERS - DSizes pre_contract_dims; + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + const TensorIndex padPlanes = numext::maxi( + 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes); + const TensorIndex padRows = numext::maxi( + 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows); + const TensorIndex padCols = numext::maxi( + 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols); + + const TensorIndex padding_top_z = padPlanes / 2; + const TensorIndex padding_bottom_z = padPlanes - padding_top_z; + const TensorIndex padding_top = padRows / 2; + const TensorIndex padding_bottom = padRows - padding_top; + const TensorIndex padding_left = padCols / 2; + const TensorIndex padding_right = padCols - padding_left; + + // Reshaped output_backward before contraction. + DSizes output_dims; if (isColMajor) { - pre_contract_dims[0] = kernelFilters; - pre_contract_dims[1] = inputRows * inputCols * inputPlanes; - pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[3] = 1; + output_dims[0] = kernelFilters; + output_dims[1] = outputPlanes * outputRows * outputCols; for (int i = 4; i < NumDims; ++i) { - pre_contract_dims[3] *= out.dimension(i); + output_dims[1] *= out.dimension(i); } } else { - pre_contract_dims[3] = kernelFilters; - pre_contract_dims[2] = inputRows * inputCols * inputPlanes; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[0] = 1; + output_dims[1] = kernelFilters; + output_dims[0] = outputCols * outputRows * outputPlanes; for (int i = 0; i < NumDims - 4; ++i) { - pre_contract_dims[0] *= out.dimension(i); + output_dims[0] *= out.dimension(i); } } - // The input has dimensions in_depth X (input_planes * input_rows * - // input_cols) X OTHERS - DSizes input_dims; + // Reshaped extract_volume_patches(in) + DSizes pre_contract_dims; if (isColMajor) { - input_dims[0] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[2] = 1; + pre_contract_dims[0] = + kernelChannels * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = outputPlanes * outputRows * outputCols; for (int i = 4; i < NumDims; ++i) { - input_dims[2] *= in.dimension(i); + pre_contract_dims[1] *= in.dimension(i); } - eigen_assert(input_dims[2] == pre_contract_dims[3]); + eigen_assert(output_dims[1] == pre_contract_dims[1]); } else { - input_dims[2] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[0] = 1; + pre_contract_dims[1] = + kernelCols * kernelRows * kernelPlanes * kernelChannels; + pre_contract_dims[0] = outputCols * outputRows * outputPlanes; for (int i = 0; i < NumDims - 4; ++i) { - input_dims[0] *= in.dimension(i); + pre_contract_dims[0] *= in.dimension(i); } - eigen_assert(input_dims[0] == pre_contract_dims[0]); + eigen_assert(output_dims[0] == pre_contract_dims[0]); } - // We will contract along dimensions (1, 2) in and (1, 3) in out, if - // this is col-major. - // For row-major, it's dimensions (0, 1) in and (0, 2) in out. - array, 2> contract_dims; - if (isColMajor) { - // col-major: in.contract(output.patches) - contract_dims[0] = IndexPair(1, 1); - contract_dims[1] = IndexPair(2, 3); - } else { - // row-major: output.patches.contract(in) - contract_dims[0] = IndexPair(0, 0); - contract_dims[1] = IndexPair(2, 1); - } + array shuffle_dims; + shuffle_dims[0] = 1; + shuffle_dims[1] = 0; - // After the contraction, the kernel will have dimension - // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols - // We will need to shuffle the first two dimensions and reverse the spatial - // dimensions. - // The end shape is: - // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); - // This is the shape of the kernel *before* the shuffling. DSizes kernel_dims; if (isColMajor) { - kernel_dims[0] = kernelChannels; - kernel_dims[1] = kernelFilters; + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels; kernel_dims[2] = kernelPlanes; kernel_dims[3] = kernelRows; kernel_dims[4] = kernelCols; } else { - kernel_dims[0] = kernelCols; - kernel_dims[1] = kernelRows; + kernel_dims[4] = kernelFilters; + kernel_dims[3] = kernelChannels; kernel_dims[2] = kernelPlanes; - kernel_dims[3] = kernelFilters; - kernel_dims[4] = kernelChannels; - } - - // Flip filters and channels. - array kernel_shuffle; - if (isColMajor) { - kernel_shuffle[0] = 1; - kernel_shuffle[1] = 0; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 3; - kernel_shuffle[4] = 4; - } else { - kernel_shuffle[0] = 0; - kernel_shuffle[1] = 1; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 4; - kernel_shuffle[4] = 3; - } - - // Reverse the spatial dimensions. - array kernel_reverse; - if (isColMajor) { - kernel_reverse[0] = false; - kernel_reverse[1] = false; - kernel_reverse[2] = true; - kernel_reverse[3] = true; - kernel_reverse[4] = true; - } else { - kernel_reverse[0] = true; - kernel_reverse[1] = true; - kernel_reverse[2] = true; - kernel_reverse[3] = false; - kernel_reverse[4] = false; + kernel_dims[1] = kernelRows; + kernel_dims[0] = kernelCols; } - DSizes strides; - for (int i = 0; i < NumDims; i++) { - strides[i] = 1; - } - if (isColMajor) { - strides[1] = stridePlanes; - strides[2] = strideRows; - strides[3] = strideCols; - } else { - strides[NumDims - 2] = stridePlanes; - strides[NumDims - 3] = strideRows; - strides[NumDims - 4] = strideCols; - } return choose( Cond::Layout == ColMajor>(), - input.reshape(input_dims) - .contract(output_backward + output_backward.reshape(output_dims) + .contract(input .extract_volume_patches( - inputPlanes, inputRows, inputCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - - padding_ztop, padding_zbottom, padding_top, - padding_bottom, padding_left, padding_right) - .reshape(pre_contract_dims), + kernelPlanes, kernelRows, kernelCols, stridePlanes, + strideRows, strideCols, 1, 1, 1, padding_top_z, + padding_bottom_z, padding_top, padding_bottom, + padding_left, padding_right) + .reshape(pre_contract_dims) + .shuffle(shuffle_dims), contract_dims) - .reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle), - output_backward - .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - padding_ztop, padding_zbottom, padding_top, + .reshape(kernel_dims), + input + .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, + stridePlanes, strideRows, strideCols, 1, 1, 1, + padding_top_z, padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims) - .contract(input.reshape(input_dims), contract_dims) - .reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle)); + .shuffle(shuffle_dims) + .contract(output_backward.reshape(output_dims), contract_dims) + .reshape(kernel_dims)); } } // end namespace Eigen -- GitLab From 97039a80b3dabb5ed2e4fb5d0d0bdc5229293718 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Tue, 4 Sep 2018 13:56:42 -0700 Subject: [PATCH 068/540] Fix CRS combiner for spatial partitioning PiperOrigin-RevId: 211519250 --- .../compiler/xla/service/hlo_domain_map.cc | 41 +++++++++++++++++++ .../compiler/xla/service/hlo_domain_map.h | 10 +++++ .../xla/service/hlo_domain_metadata.h | 3 ++ .../compiler/xla/service/hlo_domain_test.cc | 2 + .../xla/service/hlo_sharding_metadata.cc | 7 ++++ .../xla/service/hlo_sharding_metadata.h | 2 + 6 files changed, 65 insertions(+) diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc index 8b2846e0c2..113fd18eae 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -51,6 +51,10 @@ int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const { return FindOrDefault(instruction_to_domain_, instruction, -1); } +int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const { + return FindOrDie(domain_metadata_id_, instruction); +} + Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); // We only check operands, so we are sure to not process the empty domain from @@ -93,6 +97,43 @@ Status HloDomainMap::Populate(HloComputation* computation) { CreateDomain(instruction, instructions_post_order)); TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); } + TF_RETURN_IF_ERROR(PopulateDomainMetadataMap()); + return Status::OK(); +} + +Status HloDomainMap::PopulateDomainMetadataMap() { + auto hash = [](const DomainMetadata* m) { return m->Hash(); }; + auto equal = [](const DomainMetadata* a, const DomainMetadata* b) { + return a->Matches(*b); + }; + tensorflow::gtl::FlatMap + domain_metadata(1024, hash, equal); + + for (auto& domain : instruction_domains_) { + int64 domain_metadata_id = -1; + if (!domain->enter_domains.empty()) { + const HloInstruction* domain_instruction = *domain->enter_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->user_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else if (!domain->exit_domains.empty()) { + const HloInstruction* domain_instruction = *domain->exit_domains.begin(); + domain_metadata_id = + domain_metadata + .insert({&domain_instruction->operand_side_metadata(), + domain_metadata.size() + 1}) + .first->second; + } else { + domain_metadata_id = 0; + } + TF_RET_CHECK(domain_metadata_id >= 0); + for (HloInstruction* instruction : domain->instructions) { + domain_metadata_id_[instruction] = domain_metadata_id; + } + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h index 633109249a..56b557d7ce 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_map.h +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -69,6 +69,11 @@ class HloDomainMap { // instruction is not found within any domain. int64 GetDomainId(HloInstruction* instruction) const; + // Returns the unique id of the domain metadata for the domain the given + // instruction belongs to. The given instruction must not be a kDomain + // instruction since each domain instruction is associated with 2 domains. + int64 GetDomainMetadataId(HloInstruction* instruction) const; + private: // Map used for representing instruction ordering, i.e. // order_map[a] < order_map[b] means a must be ordered before b. @@ -109,9 +114,14 @@ class HloDomainMap { const tensorflow::gtl::FlatSet& instruction_set, const InstructionOrderMap& instructions_order); + // Populates domain_metadata_id_ that maps each HloInstruction to the unique + // ID of its associated domain metatadata. + Status PopulateDomainMetadataMap(); + string domain_kind_; std::vector> instruction_domains_; tensorflow::gtl::FlatMap instruction_to_domain_; + tensorflow::gtl::FlatMap domain_metadata_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h index 6c142ee474..302807f816 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -72,6 +72,9 @@ class DomainMetadata { // two matches. virtual bool Matches(const DomainMetadata& other) const = 0; + // Returns the hash value of the metadata. + virtual size_t Hash() const = 0; + // Returns a string representation of the metadata. virtual string ToString() const = 0; }; diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 974ab94467..43e74d2f6f 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -99,6 +99,8 @@ class OpNameMetadata : public DomainMetadata { static absl::string_view KindName() { return "opname"; } + size_t Hash() const override { return std::hash()(opname_); } + private: string opname_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 34cba6136f..e3f4a9852a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -422,6 +422,13 @@ bool ShardingMetadata::Matches(const DomainMetadata& other) const { : false; } +size_t ShardingMetadata::Hash() const { + if (sharding_ != nullptr) { + return sharding_->Hash(); + } + return static_cast(0x297814aaad196e6dULL); +} + string ShardingMetadata::ToString() const { return sharding_ != nullptr ? sharding_->ToString() : "{}"; } diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index cba5db927a..e3ae82a070 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -36,6 +36,8 @@ class ShardingMetadata : public DomainMetadata { bool Matches(const DomainMetadata& other) const override; + size_t Hash() const override; + string ToString() const override; const HloSharding* sharding() const { return sharding_.get(); } -- GitLab From 44a80cfa262da58d824ed6e0a7a1ffd1eea8a55b Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Tue, 4 Sep 2018 13:59:06 -0700 Subject: [PATCH 069/540] Simplify _get_grad_fn_name and other minor fixes. PiperOrigin-RevId: 211519628 --- tensorflow/python/ops/cond_v2_impl.py | 51 +++++++++++---------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py index c4e9c982b5..c6a6b2a7fa 100644 --- a/tensorflow/python/ops/cond_v2_impl.py +++ b/tensorflow/python/ops/cond_v2_impl.py @@ -180,16 +180,16 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name def _get_func_graphs(if_op): - """Returns `_FuncGraph`s for the input op branches. + """Returns `FuncGraph`s for the input op branches. Args: if_op: The _If Operation. Returns: - A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. + A 2-tuple of the `FuncGraph`s of the then_branch and else_branch. """ def _get_func_graph_for_branch(branch_name): - """Generates and returns a _FuncGraph for the given branch.""" + """Generates and returns a FuncGraph for the given branch.""" inputs = if_op.inputs[1:] # First input is pred. input_shapes = [t.shape for t in inputs] func_name = if_op.get_attr(branch_name).name @@ -197,7 +197,7 @@ def _get_func_graphs(if_op): # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g. # in the case of nested if ops or when the gradient is being computed # from inside a Defun. We build the `func_graph` with `if_op.graph` as its - # `outer_graph`. This resembles how the `_FuncGraph` was built in the + # `outer_graph`. This resembles how the `FuncGraph` was built in the # forward pass. We need this so that we can resolve references to tensors # in `func_graph` from its gradient graph in `_resolve_grad_inputs`. with if_op.graph.as_default(): @@ -221,7 +221,7 @@ def _grad_fn(func_graph, grads): func_graph's outputs w.r.t. its inputs. Args: - func_graph: function._FuncGraph. The corresponding forward-pass function. + func_graph: function.FuncGraph. The corresponding forward-pass function. grads: The list of input gradient Tensors. Returns: @@ -259,7 +259,7 @@ def _grad_fn(func_graph, grads): def _create_grad_func(func_graph, grads, name): - """Returns the _FuncGraph representation of _grad_fn.""" + """Returns the FuncGraph representation of _grad_fn.""" return _function.func_graph_from_py_func( name, lambda: _grad_fn(func_graph, grads), [], {}) @@ -277,8 +277,8 @@ def _resolve_grad_inputs(cond_graph, grad_graph): functions, this is always possible. Args: - cond_graph: function._FuncGraph. The forward-pass function. - grad_graph: function._FuncGraph. The gradients function. + cond_graph: function.FuncGraph. The forward-pass function. + grad_graph: function.FuncGraph. The gradients function. Returns: A list of inputs tensors to be passed to grad_graph. @@ -313,7 +313,7 @@ def _create_new_tf_function(func_graph): """Converts func_graph to a TF_Function and adds it to the current graph. Args: - func_graph: function._FuncGraph + func_graph: function.FuncGraph Returns: The name of the new TF_Function. @@ -365,8 +365,8 @@ def _pad_params(true_graph, false_graph, true_params, false_params): There is no merging of params. Args: - true_graph: function._FuncGraph - false_graph: function._FuncGraph + true_graph: function.FuncGraph + false_graph: function.FuncGraph true_params: a list of Tensors from true_graph false_params: a list of Tensors from false_graph @@ -391,8 +391,8 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): graph to avoid duplicating shared arguments. Args: - true_graph: function._FuncGraph - false_graph: function._FuncGraph + true_graph: function.FuncGraph + false_graph: function.FuncGraph true_inputs: a list of Tensors in the outer graph. The inputs for true_graph. false_inputs: a list of Tensors in the outer graph. The inputs for @@ -421,7 +421,7 @@ def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): _create_dummy_params(false_graph, true_only_inputs) + [false_input_to_param[t] for t in false_only_inputs]) - # Rewrite the _FuncGraphs' state to reflect the new inputs. + # Rewrite the FuncGraphs' state to reflect the new inputs. true_graph.captures = collections.OrderedDict(zip(new_inputs, true_graph.inputs)) false_graph.captures = collections.OrderedDict(zip(new_inputs, @@ -434,7 +434,7 @@ def _create_dummy_params(func_graph, template_tensors): """Creates tensors in func_graph to represent template_tensors. Args: - func_graph: function._FuncGraph. + func_graph: function.FuncGraph. template_tensors: a list of tensors in the outer graph. Returns: @@ -451,27 +451,16 @@ def _get_grad_fn_name(func_graph): Ensures this name is unique in the entire hierarchy. Args: - func_graph: The _FuncGraph. + func_graph: The FuncGraph. Returns: A string, the name to use for the gradient function. """ name = "%s_grad" % func_graph.name - - base_name = name - counter = 1 - has_conflict = True - while has_conflict: - curr_graph = func_graph.outer_graph - has_conflict = curr_graph._is_function(name) - while not has_conflict and isinstance(curr_graph, _function.FuncGraph): - curr_graph = curr_graph.outer_graph - has_conflict = curr_graph._is_function(name) - if has_conflict: - name = "%s_%s" % (base_name, counter) - counter += 1 - - return name + outer_most_graph = func_graph + while isinstance(outer_most_graph, _function.FuncGraph): + outer_most_graph = outer_most_graph.outer_graph + return outer_most_graph.unique_name(name) def _check_same_outputs(true_graph, false_graph): -- GitLab From 8cf8afefdb4c240f74a05e24246c8cd2dcce9d54 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Tue, 4 Sep 2018 13:59:25 -0700 Subject: [PATCH 070/540] Internal Change. PiperOrigin-RevId: 211519679 --- tensorflow/contrib/__init__.py | 8 ++++++++ tensorflow/python/__init__.py | 7 +++++++ tensorflow/python/tools/component_api_helper.py | 2 +- 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 5f477a79a3..9478e42b46 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,6 +21,14 @@ from __future__ import print_function import os +from tensorflow.python.tools import component_api_helper +component_api_helper.package_hook( + parent_package_str=( + "tensorflow.contrib"), + child_package_str=( + "tensorflow_estimator.contrib.estimator")) +del component_api_helper + # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index a2ab63bb48..4921ecc43c 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -48,6 +48,13 @@ import numpy as np from tensorflow.python import pywrap_tensorflow +from tensorflow.python.tools import component_api_helper +component_api_helper.package_hook( + parent_package_str='tensorflow.python', + child_package_str=( + 'tensorflow_estimator.python.estimator')) +del component_api_helper + # Protocol buffers from tensorflow.core.framework.graph_pb2 import * from tensorflow.core.framework.node_def_pb2 import * diff --git a/tensorflow/python/tools/component_api_helper.py b/tensorflow/python/tools/component_api_helper.py index 988ecc61f0..e261758add 100644 --- a/tensorflow/python/tools/component_api_helper.py +++ b/tensorflow/python/tools/component_api_helper.py @@ -67,7 +67,7 @@ def package_hook(parent_package_str, child_package_str, error_msg=None): """ child_pkg_path = [os.path.join(os.path.dirname(child_pkg.__file__), "..")] try: - parent_pkg.__path__ += child_pkg_path + parent_pkg.__path__ = child_pkg_path + parent_pkg.__path__ except AttributeError: parent_pkg.__path__ = child_pkg_path -- GitLab From 06e8109af2e5ae5bc149e25fc64fbf66d6c8b817 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 4 Sep 2018 14:01:13 -0700 Subject: [PATCH 071/540] [tf.data] Add internal optimizations for executing simple functions in `MapDataset`. PiperOrigin-RevId: 211520001 --- .../contrib/data/python/ops/interleave_ops.py | 17 +- tensorflow/contrib/data/python/ops/readers.py | 6 +- tensorflow/core/graph/testlib.cc | 27 ++ tensorflow/core/graph/testlib.h | 9 + tensorflow/core/kernels/data/BUILD | 37 ++ .../core/kernels/data/captured_function.cc | 20 +- .../core/kernels/data/captured_function.h | 13 +- .../core/kernels/data/map_dataset_op.cc | 6 +- .../kernels/data/single_threaded_executor.cc | 378 ++++++++++++++++++ .../kernels/data/single_threaded_executor.h | 60 +++ .../data/single_threaded_executor_test.cc | 330 +++++++++++++++ .../core/kernels/save_restore_tensor.cc | 9 +- tensorflow/core/ops/dataset_ops.cc | 1 + .../data/kernel_tests/map_dataset_op_test.py | 107 ++--- tensorflow/python/data/ops/dataset_ops.py | 4 +- 15 files changed, 963 insertions(+), 61 deletions(-) create mode 100644 tensorflow/core/kernels/data/single_threaded_executor.cc create mode 100644 tensorflow/core/kernels/data/single_threaded_executor.h create mode 100644 tensorflow/core/kernels/data/single_threaded_executor_test.cc diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 38c0a09c33..92d4251a86 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -220,6 +220,7 @@ def sample_from_datasets(datasets, weights=None, seed=None): if weights is None: # Select inputs with uniform probability. logits = [[1.0] * num_datasets] + else: # Use the given `weights` as the probability of choosing the respective # input. @@ -245,8 +246,11 @@ def sample_from_datasets(datasets, weights=None, seed=None): return array_ops.squeeze( stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = random_ops.RandomDataset(seed).batch(2).map( - select_dataset_constant_logits) + selector_input = dataset_ops.MapDataset( + random_ops.RandomDataset(seed).batch(2), + select_dataset_constant_logits, + use_inter_op_parallelism=False) + else: # Use each element of the given `weights` dataset as the probability of # choosing the respective input. @@ -259,9 +263,12 @@ def sample_from_datasets(datasets, weights=None, seed=None): return array_ops.squeeze( stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) - selector_input = dataset_ops.Dataset.zip( - (logits_ds, random_ops.RandomDataset(seed).batch(2) - )).map(select_dataset_varying_logits) + logits_and_seeds = dataset_ops.Dataset.zip( + (logits_ds, random_ops.RandomDataset(seed).batch(2))) + selector_input = dataset_ops.MapDataset( + logits_and_seeds, + select_dataset_varying_logits, + use_inter_op_parallelism=False) return _DirectedInterleaveDataset(selector_input, datasets) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 7f09ba71dc..4c466781f7 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -499,7 +499,8 @@ def make_csv_dataset( # indefinitely, and all batches will be full-sized. dataset = dataset.batch(batch_size=batch_size, drop_remainder=num_epochs is None) - dataset = dataset.map(map_fn) + dataset = dataset_ops.MapDataset( + dataset, map_fn, use_inter_op_parallelism=False) dataset = dataset.prefetch(prefetch_buffer_size) return dataset @@ -778,7 +779,8 @@ def make_batched_features_dataset(file_pattern, # Extract values if the `Example` tensors are stored as key-value tuples. if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda _, v: v) + dataset = dataset_ops.MapDataset( + dataset, lambda _, v: v, use_inter_op_parallelism=False) # Apply dataset repeat and shuffle transformations. dataset = _maybe_shuffle_and_repeat( diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index ea7788f654..0a38aa1c91 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) { return ret; } +Node* CheckNumerics(Graph* g, Node* in, const string& message) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics") + .Input(in) + .Attr("message", message) + .Finalize(g, &ret)); + return ret; +} + +Node* Arg(Graph* g, int64 index, DataType type) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg") + .Attr("T", type) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + +Node* Retval(Graph* g, int64 index, Node* in) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval") + .Input(in) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } } // end namespace graph diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 8585b35a19..bd0284d43a 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type); // Add a DiagPart node in "g". Node* DiagPart(Graph* g, Node* in, DataType type); +// Add a CheckNumerics node in "g". +Node* CheckNumerics(Graph* g, Node* in, const string& message); + +// Add an _Arg node in "g". +Node* Arg(Graph* g, int64 index, DataType type); + +// Add a _Retval node in "g". +Node* Retval(Graph* g, int64 index, Node* in); + } // end namespace graph } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index e7b3d0c92f..3a1ac73f64 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -51,6 +51,7 @@ cc_library( hdrs = ["captured_function.h"], deps = [ ":dataset", + ":single_threaded_executor", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -60,6 +61,42 @@ cc_library( ], ) +cc_library( + name = "single_threaded_executor", + srcs = ["single_threaded_executor.cc"], + hdrs = ["single_threaded_executor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "single_threaded_executor_test", + srcs = ["single_threaded_executor_test.cc"], + deps = [ + ":single_threaded_executor", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + cc_library( name = "window_dataset", srcs = ["window_dataset.cc"], diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index abdf6ee4e8..186740c2ac 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -28,7 +28,16 @@ namespace tensorflow { Status CapturedFunction::Create( const NameAttrList& func, std::vector captured_inputs, std::unique_ptr* out_function) { - out_function->reset(new CapturedFunction(func, std::move(captured_inputs))); + return Create(func, std::move(captured_inputs), true, out_function); +} + +/* static */ +Status CapturedFunction::Create( + const NameAttrList& func, std::vector captured_inputs, + bool use_inter_op_parallelism, + std::unique_ptr* out_function) { + out_function->reset(new CapturedFunction(func, std::move(captured_inputs), + use_inter_op_parallelism)); return Status::OK(); } @@ -272,6 +281,9 @@ Status CapturedFunction::Instantiate(IteratorContext* ctx) { inst_opts.overlay_lib = ctx->function_library().get(); inst_opts.state_handle = std::to_string(random::New64()); inst_opts.create_kernels_eagerly = true; + if (!use_inter_op_parallelism_) { + inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; + } Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_opts, &f_handle_)); TF_RETURN_IF_ERROR(s); @@ -398,10 +410,12 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, } CapturedFunction::CapturedFunction(const NameAttrList& func, - std::vector captured_inputs) + std::vector captured_inputs, + bool use_inter_op_parallelism) : func_(func), lib_(nullptr), f_handle_(kInvalidHandle), - captured_inputs_(std::move(captured_inputs)) {} + captured_inputs_(std::move(captured_inputs)), + use_inter_op_parallelism_(use_inter_op_parallelism) {} } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index c95f2b1c01..ae6bdfc2a0 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -48,6 +48,15 @@ class CapturedFunction { std::vector captured_inputs, std::unique_ptr* out_function); + // Creates a new instance from a list of named attributes and captured inputs. + // + // If `low_latency_hint` is true, the runtime may use an executor that is + // optimized for small functions. + static Status Create(const NameAttrList& func, + std::vector captured_inputs, + bool use_inter_op_parallelism, + std::unique_ptr* out_function); + // Creates a new instance using a list of named attributes, fetching captured // inputs from a context argument. static Status Create(const NameAttrList& func, OpKernelContext* ctx, @@ -114,7 +123,8 @@ class CapturedFunction { private: CapturedFunction(const NameAttrList& func, - std::vector captured_inputs); + std::vector captured_inputs, + bool use_inter_op_parallelism); Status GetHandle(IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle); @@ -126,6 +136,7 @@ class CapturedFunction { const std::vector captured_inputs_; DataTypeSlice ret_types_; std::function)> captured_runner_ = nullptr; + const bool use_inter_op_parallelism_; TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); }; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 7f8182d917..6c45fcafcc 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -34,6 +34,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, @@ -48,7 +50,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + func_, std::move(other_arguments), + use_inter_op_parallelism_, &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); @@ -187,6 +190,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { DataTypeVector output_types_; std::vector output_shapes_; NameAttrList func_; + bool use_inter_op_parallelism_; }; REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp); diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc new file mode 100644 index 0000000000..e785b8b4d5 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -0,0 +1,378 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/single_threaded_executor.h" + +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/executor_factory.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +typedef gtl::InlinedVector TensorValueVec; +typedef gtl::InlinedVector DeviceContextVec; +typedef gtl::InlinedVector AllocatorAttributeVec; + +class SingleThreadedExecutorImpl : public Executor { + public: + explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params) + : params_(params) {} + + ~SingleThreadedExecutorImpl() override { + for (const KernelState& kernel_state : kernels_) { + params_.delete_kernel(kernel_state.kernel); + } + } + + Status Initialize(const Graph& graph) { + // Topologicially sort `graph` to get a sequence of OpKernels. + std::vector ordered_nodes; + ordered_nodes.reserve(graph.num_nodes()); + GetReversePostOrder(graph, &ordered_nodes); + + if (ordered_nodes.size() != graph.num_nodes()) { + return errors::InvalidArgument("Graph had ", graph.num_nodes(), + " but reverse post-order had ", + ordered_nodes.size()); + } + + kernels_.resize(ordered_nodes.size()); + + std::unordered_map node_to_index_map; + + // Create the kernel and input-related structures for each node in `graph`. + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + Node* n = ordered_nodes[i]; + node_to_index_map[n] = i; + + for (DataType dt : n->output_types()) { + if (IsRefType(dt)) { + return errors::Unimplemented( + "Single-threaded executor does not support reference-typed " + "edges."); + } + } + + if (n->IsControlFlow()) { + return errors::Unimplemented( + "Single-threaded executor does not support control flow."); + } + if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) { + return errors::Unimplemented( + "Single-threaded executor does not support partitioned graphs."); + } + if (n->IsCollective()) { + return errors::Unimplemented( + "Single-threaded executor does not support collective ops."); + } + + KernelState& kernel_state = kernels_[i]; + TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel)); + kernel_state.num_inputs = n->num_inputs(); + kernel_state.num_outputs = n->num_outputs(); + + if (i == 0) { + kernel_state.input_start_index = 0; + } else { + const KernelState& previous_kernel_state = kernels_[i - 1]; + kernel_state.input_start_index = + previous_kernel_state.input_start_index + + previous_kernel_state.num_inputs; + } + } + + // Build the mapping from each node output to the input slot for the + // corresponding destination node. + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + Node* n = ordered_nodes[i]; + KernelState& kernel_state = kernels_[i]; + kernel_state.output_locations.resize(kernel_state.num_outputs); + for (const Edge* e : n->out_edges()) { + if (!e->IsControlEdge()) { + kernel_state.output_locations[e->src_output()].push_back( + kernels_[node_to_index_map[e->dst()]].input_start_index + + e->dst_input()); + } + } + + // Compute allocator attributes for each node output, and corresponding + // node input. + kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs); + AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data(); + + OpKernel* op_kernel = kernel_state.kernel; + for (int out = 0; out < n->num_outputs(); out++) { + DCHECK_LT(out, op_kernel->output_memory_types().size()); + bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; + if (on_host) { + AllocatorAttributes h; + h.set_on_host(on_host); + attrs[out].Merge(h); + } + } + } + + if (!kernels_.empty()) { + const KernelState& last_kernel_state = kernels_.back(); + total_num_inputs_ = + last_kernel_state.input_start_index + last_kernel_state.num_inputs; + input_alloc_attrs_.resize(total_num_inputs_); + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) { + for (size_t output_location : kernels_[i].output_locations[j]) { + input_alloc_attrs_[output_location] = + kernels_[i].output_alloc_attrs[j]; + } + } + } + } else { + total_num_inputs_ = 0; + } + return Status::OK(); + } + + // TODO(mrry): Consider specializing the implementation of Executor::Run() + // instead, to avoid unnecessary atomic operations in the callback when + // running synchronously. + void RunAsync(const Args& args, DoneCallback done) override { + // The inputs to each kernel are stored contiguously in `inputs`. + // + // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to + // determine the range of elements in this vector that correspond to + // the inputs of `kernels_[i]`. + // + // This vector has the following layout: + // + // * Kernel 0, input 0. + // * Kernel 0, input 1. + // * ... + // * Kernel 0, input `kernels_[0].num_inputs - 1`. + // * Kernel 1, input 0. + // * ... + // * Kernel 1, input `kernels_[1].num_inputs - 1`. + // * ... + // * Kernel `kernels_.size() - 1`, input 0. + // * ... + // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`. + // + // Note that kernels with zero inputs do not correspond to any elements in + // this vector. + // + // We use `ManualConstructor` to avoid the overhead of + // default-constructing an invalid `Tensor` for each slot at the beginning + // of execution: + // * Elements are initialized when the outputs of a kernel execution are + // propagated to the inputs of kernels that depend on them. + // * The elements corresponding to the inputs for kernel `i` are destroyed + // after kernel `i` executes. + // * In an error case (see below), we use the connectivity information in + // `KernelState::output_locations` to determine which locations have been + // initialized, and manually destroy them. + std::vector> inputs(total_num_inputs_); + + // TODO(mrry): Can we avoid copying into these vectors? Consider modifying + // OpKernelContext to take the TensorValueVec as a pointer into `inputs`. + TensorValueVec node_inputs; + DeviceContextVec input_device_contexts; + AllocatorAttributeVec input_alloc_attrs; + + // Prepare the parameters that will be the same for all kernels. + OpKernelContext::Params params; + params.step_id = args.step_id; + Device* device = params_.device; + params.device = device; + params.log_memory = false; // TODO(mrry): Too severe? + params.record_tensor_accesses = false; // TODO(mrry): Too severe? + params.rendezvous = args.rendezvous; + params.session_state = args.session_state; + params.tensor_store = args.tensor_store; + params.cancellation_manager = args.cancellation_manager; + // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor + // allocations that it performs. Consider specializing its handling in the + // executor. + params.call_frame = args.call_frame; + params.function_library = params_.function_library; + params.resource_manager = device->resource_manager(); + params.step_container = args.step_container; + params.slice_reader_cache = nullptr; // TODO(mrry): Too severe? + params.inputs = &node_inputs; + params.input_device_contexts = &input_device_contexts; + params.input_alloc_attrs = &input_alloc_attrs; + + Args::Runner runner_copy = args.runner; + params.runner = &runner_copy; + params.stats_collector = args.stats_collector; + + // NOTE(mrry): We are assuming that the graph is loopless and condless. + params.frame_iter = FrameAndIter(0, 0); + params.is_input_dead = false; + + // TODO(mrry): Add non-default device context inference. + params.op_device_context = nullptr; + // TODO(mrry): Consider implementing forwarding. + params.forward_from_array = nullptr; + + // Execute the kernels one-at-a-time in topological order. + for (size_t i = 0; i < kernels_.size(); ++i) { + const KernelState& kernel_state = kernels_[i]; + + // Prepare the per-kernel parameters. + const size_t input_start_index = kernel_state.input_start_index; + const size_t num_inputs = kernel_state.num_inputs; + const size_t num_outputs = kernel_state.num_outputs; + + node_inputs.clear(); + node_inputs.resize(num_inputs); + input_alloc_attrs.clear(); + input_alloc_attrs.resize(num_inputs); + for (size_t j = 0; j < num_inputs; ++j) { + auto t = inputs[input_start_index + j].get(); + node_inputs[j].tensor = t; + input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j]; + } + params.op_kernel = kernel_state.kernel; + input_device_contexts.clear(); + input_device_contexts.resize(num_inputs); + params.output_attr_array = kernel_state.output_alloc_attrs.data(); + OpKernelContext ctx(¶ms, num_outputs); + + // Actually execute the kernel. + device->Compute(kernel_state.kernel, &ctx); + + if (!ctx.status().ok()) { + // On failure, we must manually free all intermediate tensors. We have + // already freed all the inputs for kernels up to (but not including) + // the `i`th kernel. We scan through the previously executed kernels and + // destroy any tensors that were destined to be the input for a kernel + // that has not yet executed. + for (size_t j = 0; j < i; ++j) { + const KernelState& executed_kernel_state = kernels_[j]; + for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) { + for (size_t output_location : + executed_kernel_state.output_locations[k]) { + if (output_location >= input_start_index) { + // Only destroy an output location if it is an input to an + // operation that has not yet executed. + inputs[output_location].Destroy(); + } + } + } + } + done(ctx.status()); + return; + } + + // Free the inputs to the current kernel. + for (size_t j = 0; j < num_inputs; ++j) { + inputs[input_start_index + j].Destroy(); + } + + // Forward the outputs of the kernel to the inputs of subsequent kernels. + for (size_t j = 0; j < num_outputs; ++j) { + TensorValue val = ctx.release_output(j); + // TODO(mrry): Consider flattening the `output_locations` vector + // to improve the cache-friendliness of this loop. + for (size_t output_location : kernel_state.output_locations[j]) { + // TODO(mrry): Validate that the types match the expected values or + // ensure that the necessary validation has already happened. + inputs[output_location].Init(*val.tensor); + } + delete val.tensor; + } + } + done(Status::OK()); + } + + private: + const LocalExecutorParams params_; + + // All following members are read-only after Initialize(). + + // The sum of the number of inputs for each node in the graph. This determines + // the length of the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. + size_t total_num_inputs_; + + // Represents cached graph structure state for each kernel. + struct KernelState { + // The kernel object. Not owned. + // + // This pointer is managed by `params_.create_kernel()` and + // `params_.delete_kernel()`. + OpKernel* kernel; + + // These fields determine the range of elements in `inputs` that corresponds + // to the inputs of `kernel`. + size_t input_start_index; + size_t num_inputs; + + size_t num_outputs; + + // For the `j`th output of `kernel`, `output_locations[j]` contains the + // locations in the flat `inputs` vector to which that output must be + // copied. See comment at the beginning of `RunAsync()` for details. + std::vector> + output_locations; // Length = `num_outputs`. + + // Memory space information for each output of `kernel`. + std::vector + output_alloc_attrs; // Length = `num_outputs`. + }; + std::vector kernels_; + + // Memory space information for each input. This information is stored in the + // same order as the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. + std::vector + input_alloc_attrs_; // Length = `total_num_inputs_`. +}; + +class SingleThreadedExecutorRegistrar { + public: + SingleThreadedExecutorRegistrar() { + ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory()); + } + + private: + class Factory : public ExecutorFactory { + Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr graph, + std::unique_ptr* out_executor) override { + Executor* ret; + TF_RETURN_IF_ERROR( + NewSingleThreadedExecutor(params, std::move(graph), &ret)); + out_executor->reset(ret); + return Status::OK(); + } + }; +}; +static SingleThreadedExecutorRegistrar registrar; + +} // namespace + +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr graph, + Executor** executor) { + std::unique_ptr impl( + new SingleThreadedExecutorImpl(params)); + TF_RETURN_IF_ERROR(impl->Initialize(*graph)); + *executor = impl.release(); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h new file mode 100644 index 0000000000..15836b24c9 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ + +#include "tensorflow/core/common_runtime/executor.h" + +namespace tensorflow { + +// Creates a new `Executor` for executing `graph` synchronously on the caller +// thread. +// +// NOTE(mrry): The returned executor is optimized to impose low overhead on +// graphs that perform a small amount of work (e.g. <15us of work per graph on +// present architectures). It eschews concurrency, because issuing work to +// multiple threads can dominate the cost of executing small ops synchronously, +// and because contention in the executor data structures can reduce throughput +// (in terms of ops executed per unit time). +// +// However, the current implementation has the following limitations: +// +// 1. Reference-typed tensors are not supported and will not be supported in +// future. +// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not +// currently supported. The current plan is to extend support to "functional" +// control flow after the TensorFlow APIs transition to building graphs in +// that form (e.g. `tf.cond_v2()`). +// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported. +// The present implementation executes kernels one at a time in topological +// order, and cannot currently distinguish between disconnected subgraphs +// that are logically connected by subgraphs on a different device. +// 4. Memory logging is not currently supported. +// 5. Allocation forwarding is not currently supported. +// 6. Non-default device contexts are not currently supported. In effect, this +// limits the executor to CPU devices. +// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null +// are not currently supported. +// +// The single-threaded executor is primarily suitable for executing simple +// TensorFlow functions, such as one might find in a `tf.data` pipeline. +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr graph, + Executor** executor); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc new file mode 100644 index 0000000000..f8b5769197 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/data/single_threaded_executor.h" + +#include + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +class ExecutorTest : public ::testing::Test { + protected: + ExecutorTest() + : device_(DeviceFactory::NewDevice("CPU", {}, + "/job:localhost/replica:0/task:0")) {} + + ~ExecutorTest() override { + // There should always be exactly one Ref left on the Rendezvous + // when the test completes. + CHECK(rendez_->Unref()); + delete exec_; + delete device_; + } + + // Resets executor_ with a new executor based on a graph 'gdef'. + void Create(std::unique_ptr graph) { + const int version = graph->versions().producer(); + LocalExecutorParams params; + params.device = device_; + params.create_kernel = [this, version](const NodeDef& ndef, + OpKernel** kernel) { + return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); + }; + params.delete_kernel = [](OpKernel* kernel) { + DeleteNonCachedKernel(kernel); + }; + delete exec_; + TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_)); + runner_ = [](std::function fn) { fn(); }; + rendez_ = NewLocalRendezvous(); + } + + Status Run(Rendezvous* rendez) { + Executor::Args args; + args.rendezvous = rendez; + args.runner = runner_; + return exec_->Run(args); + } + + Status Run(CallFrameInterface* call_frame) { + Executor::Args args; + args.call_frame = call_frame; + args.runner = runner_; + return exec_->Run(args); + } + + Device* device_ = nullptr; + Executor* exec_ = nullptr; + Executor::Args::Runner runner_; + Rendezvous* rendez_ = nullptr; +}; + +// A float val -> Tensor +Tensor V(const float val) { + Tensor tensor(DT_FLOAT, TensorShape({})); + tensor.scalar()() = val; + return tensor; +} + +// A int32 val -> Tensor +Tensor VI(const int32 val) { + Tensor tensor(DT_INT32, TensorShape({})); + tensor.scalar()() = val; + return tensor; +} + +// A bool val -> Tensor +Tensor VB(const bool val) { + Tensor tensor(DT_BOOL, TensorShape({})); + tensor.scalar()() = val; + return tensor; +} + +// A double val -> Tensor +Tensor VD(const double val) { + Tensor tensor(DT_DOUBLE, TensorShape({})); + tensor.scalar()() = val; + return tensor; +} + +// Tensor -> a float val. +float V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_FLOAT); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar()(); +} + +Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, + const string& receiver, const string& name) { + Rendezvous::ParsedKey result; + TF_CHECK_OK( + Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, + name, FrameAndIter(0, 0)), + &result)); + return result; +} + +TEST_F(ExecutorTest, SimpleAdd) { + // c = a + b + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT); + auto in1 = test::graph::Arg(g.get(), 0, DT_FLOAT); + auto tmp = test::graph::Add(g.get(), in0, in1); + test::graph::Retval(g.get(), 0, tmp); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); + TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(1.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(2.0, V(retvals[0])); // out = 1.0 + 1.0 = 2.0 +} + +TEST_F(ExecutorTest, SelfAdd) { + // v0 <- a + // v1 = v0 + v0 + // v2 = v1 + v1 + // ... ... + // v10 = v9 + v9 + // + // b <- v10 + // All nodes are executed by one thread. + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto v = test::graph::Arg(g.get(), 0, DT_FLOAT); + const int N = 10; + for (int i = 1; i <= N; ++i) { + v = test::graph::Add(g.get(), v, v); + } + // out <- v10 + test::graph::Retval(g.get(), 0, v); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT}); + // a = 1.0 + TF_ASSERT_OK(call_frame.SetArgs({V(1.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(1024.0, V(retvals[0])); // b=v10=2*v9=4*v8=...=1024*a=1024.0 +} + +// Builds a graph which adds N copies of one variable "in". I.e., +// a + a + a + ... + a +// The returned graph is parenthesized ramdonly. I.e., +// a + ((a + a) + a) +// (a + a) + (a + a) +// ((a + a) + a) + a +// are all possibly generated. +void BuildTree(int N, Graph* g) { + CHECK_GT(N, 1); + // A single input node "in". + auto in = test::graph::Arg(g, 0, DT_FLOAT); + std::vector nodes; + int i = 0; + // Duplicate "in" N times. Each copies is named as l0, l1, l2, .... + for (; i < N; ++i) { + nodes.push_back(test::graph::Identity(g, in, 0)); + } + random::PhiloxRandom philox(0, 17); + random::SimplePhilox rnd(&philox); + while (nodes.size() > 1) { + // Randomly pick two from nodes and add them. The resulting node + // is named lik n10, n11, .... and is put back into "nodes". + int x = rnd.Uniform(nodes.size()); + auto in0 = nodes[x]; + nodes[x] = nodes.back(); + nodes.resize(nodes.size() - 1); + x = rnd.Uniform(nodes.size()); + auto in1 = nodes[x]; + // node = in0 + in1. + nodes[x] = test::graph::Add(g, in0, in1); + } + // The final output node "out". + test::graph::Retval(g, 0, nodes.back()); + FixupSourceAndSinkEdges(g); +} + +TEST_F(ExecutorTest, RandomTree) { + std::unique_ptr g(new Graph(OpRegistry::Global())); + BuildTree(4096, g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT}); + TF_ASSERT_OK(call_frame.SetArgs({V(1.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(4096.0, V(retvals[0])); +} + +TEST_F(ExecutorTest, OpError) { + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto zero = test::graph::Constant(g.get(), V(0.0)); + auto inf = test::graph::Unary(g.get(), "Reciprocal", zero); + auto check = test::graph::CheckNumerics(g.get(), inf, "message"); + auto two = test::graph::Constant(g.get(), V(2.0)); + test::graph::Binary(g.get(), "Mul", check, two); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({}, {}); + // Fails due to invalid dtype. + EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame))); +} + +static void BM_executor(int iters, int width, int depth) { +#ifdef PLATFORM_GOOGLE + BenchmarkUseRealTime(); +#endif // PLATFORM_GOOGLE + Graph* g = new Graph(OpRegistry::Global()); + random::PhiloxRandom philox(1729, 17); + random::SimplePhilox rand(&philox); + uint64 cur = 0; + uint32 r = 1 + rand.Rand32() % width; + std::vector ready_nodes; + for (int i = 0; i < r; ++i) { + ready_nodes.push_back(test::graph::NoOp(g, {})); + ++cur; + } + for (int i = 0; i < depth; ++i) { + std::random_shuffle(ready_nodes.begin(), ready_nodes.end()); + r = 1 + rand.Rand32() % (ready_nodes.size()); + std::vector control_inputs; + for (int j = 0; j < r; ++j) { + control_inputs.push_back(ready_nodes.back()); + ready_nodes.pop_back(); + } + Node* n = test::graph::NoOp(g, control_inputs); + ++cur; + r = 1 + rand.Rand32() % width; + for (int j = 0; j < r; ++j) { + ready_nodes.push_back(test::graph::NoOp(g, {n})); + ++cur; + } + } + FixupSourceAndSinkEdges(g); +#ifdef PLATFORM_GOOGLE + SetBenchmarkLabel(strings::StrCat("Nodes = ", cur)); + SetBenchmarkItemsProcessed(cur * static_cast(iters)); +#endif // PLATFORM_GOOGLE + test::Benchmark("cpu", g, nullptr, nullptr, nullptr, + "SINGLE_THREADED_EXECUTOR") + .Run(iters); +} + +// Tall skinny graphs +BENCHMARK(BM_executor)->ArgPair(16, 1024); +BENCHMARK(BM_executor)->ArgPair(32, 8192); + +// Short fat graphs +BENCHMARK(BM_executor)->ArgPair(1024, 16); +BENCHMARK(BM_executor)->ArgPair(8192, 32); + +// Tall fat graph +BENCHMARK(BM_executor)->ArgPair(1024, 1024); + +// TODO(mrry): This benchmark currently crashes with a use-after free, because +// test::Benchmark::RunWithArgs() assumes that the executor will take ownership +// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the +// duration of the benchmark. Since the single threaded executor does not retain +// a copy of the graph, this fails. +// +// TODO(mrry): Add support for Arg/Retval "function call convention" in +// `test::Benchmark::RunWithArgs()`. +#if 0 +#define ALICE "/job:j/replica:0/task:0/cpu:0" +#define BOB "/job:j/replica:0/task:0/gpu:0" + +static void BM_FeedInputFetchOutput(int iters) { + Graph* g = new Graph(OpRegistry::Global()); + // z = x + y: x and y are provided as benchmark inputs. z is the + // output of the benchmark. Conceptually, the caller is ALICE, the + // benchmark is BOB. + Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB); + Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB); + Node* sum = test::graph::Add(g, x, y); + Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE); + FixupSourceAndSinkEdges(g); + Tensor val(DT_FLOAT, TensorShape({})); + val.scalar()() = 3.14; + SetBenchmarkItemsProcessed(static_cast(iters)); + test::Benchmark("cpu", g, nullptr, nullptr, nullptr, + "SINGLE_THREADED_EXECUTOR") + .RunWithArgs({{x, val}, {y, val}}, {z}, iters); +} +BENCHMARK(BM_FeedInputFetchOutput); +#endif + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index e335e38bdc..82546d581a 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context, // If we cannot find a cached reader we will allocate our own. std::unique_ptr allocated_reader; - const checkpoint::TensorSliceReader* reader = - context->slice_reader_cache()->GetReader(file_pattern, open_func, - preferred_shard); + const checkpoint::TensorSliceReader* reader = nullptr; + + if (context->slice_reader_cache()) { + reader = context->slice_reader_cache()->GetReader(file_pattern, open_func, + preferred_shard); + } if (!reader) { allocated_reader.reset(new checkpoint::TensorSliceReader( file_pattern, open_func, preferred_shard)); diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index f03639e833..1a5ad8f421 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -198,6 +198,7 @@ REGISTER_OP("MapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("ParallelMapDataset") diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 52b4320bf1..df2c9b170a 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -711,57 +711,74 @@ class MapDatasetBenchmark(test.Benchmark): def benchmarkChainOfMaps(self): chain_lengths = [0, 1, 2, 5, 10, 20, 50] for chain_length in chain_lengths: - with ops.Graph().as_default(): - dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) - for _ in range(chain_length): - dataset = dataset.map(lambda x: x) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with session.Session() as sess: - for _ in range(5): - sess.run(next_element.op) - deltas = [] - for _ in range(100): - start = time.time() - for _ in range(100): + for use_inter_op_parallelism in [False, True]: + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors(0).repeat(None) + for _ in range(chain_length): + dataset = dataset_ops.MapDataset( + dataset, + lambda x: x, + use_inter_op_parallelism=use_inter_op_parallelism) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): sess.run(next_element.op) - end = time.time() - deltas.append(end - start) - - median_wall_time = np.median(deltas) / 100 - print("Map dataset chain length: %d Median wall time: %f" - % (chain_length, median_wall_time)) - self.report_benchmark( - iters=1000, wall_time=median_wall_time, - name="benchmark_map_dataset_chain_latency_%d" % chain_length) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element.op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + print("Map dataset chain length%s: %d Median wall time: %f" % + (" (single threaded mode)" if not use_inter_op_parallelism + else "", chain_length, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_chain_latency_%d%s" % + (chain_length, "_single_threaded" + if not use_inter_op_parallelism else "")) def benchmarkMapFanOut(self): fan_outs = [1, 2, 5, 10, 20, 50, 100] for fan_out in fan_outs: - with ops.Graph().as_default(): - dataset = dataset_ops.Dataset.from_tensors( - tuple(0 for _ in range(fan_out))).repeat(None).map(lambda *xs: xs) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with session.Session() as sess: - for _ in range(5): - sess.run(next_element[0].op) - deltas = [] - for _ in range(100): - start = time.time() - for _ in range(100): + for use_inter_op_parallelism in [False, True]: + with ops.Graph().as_default(): + dataset = dataset_ops.Dataset.from_tensors( + tuple(0 for _ in range(fan_out))).repeat(None) + dataset = dataset_ops.MapDataset( + dataset, + lambda *xs: xs, + use_inter_op_parallelism=use_inter_op_parallelism) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with session.Session() as sess: + for _ in range(5): sess.run(next_element[0].op) - end = time.time() - deltas.append(end - start) - - median_wall_time = np.median(deltas) / 100 - print("Map dataset fan out: %d Median wall time: %f" - % (fan_out, median_wall_time)) - self.report_benchmark( - iters=1000, wall_time=median_wall_time, - name="benchmark_map_dataset_fan_out_%d" % fan_out) + deltas = [] + for _ in range(100): + start = time.time() + for _ in range(100): + sess.run(next_element[0].op) + end = time.time() + deltas.append(end - start) + + median_wall_time = np.median(deltas) / 100 + print("Map dataset fan out%s: %d Median wall time: %f" % + (" (single threaded mode)" if not use_inter_op_parallelism + else "", fan_out, median_wall_time)) + self.report_benchmark( + iters=1000, + wall_time=median_wall_time, + name="benchmark_map_dataset_fan_out_%d%s" % + (fan_out, "_single_threaded" + if not use_inter_op_parallelism else "")) if __name__ == "__main__": diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 8c37b1871b..6205ee392e 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2207,10 +2207,11 @@ def _warn_if_collections(transformation_name): class MapDataset(Dataset): """A `Dataset` that maps a function over elements in its input.""" - def __init__(self, input_dataset, map_func): + def __init__(self, input_dataset, map_func, use_inter_op_parallelism=True): """See `Dataset.map()` for details.""" super(MapDataset, self).__init__() self._input_dataset = input_dataset + self._use_inter_op_parallelism = use_inter_op_parallelism wrapped_func = StructuredFunctionWrapper( map_func, "Dataset.map()", input_dataset) @@ -2225,6 +2226,7 @@ class MapDataset(Dataset): input_t, self._map_func.captured_inputs, f=self._map_func, + use_inter_op_parallelism=self._use_inter_op_parallelism, **flat_structure(self)) @property -- GitLab From d29eb6d1c9d1e4b2f601864f53878674f219fe6f Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Tue, 4 Sep 2018 14:03:08 -0700 Subject: [PATCH 072/540] Remove reference cycles when constructing distribution objects self -> _parameters -> self cycles were creating work for Python's garbage collector in training loops, where Distribution objects may be created repeatedly when executing eagerly. This CL just fixes that narrow memory issue; I'm not convinced dict(locals()) is super efficient, so we may want to follow up on that for performance. Adds a few unit tests tests with run_test_in_graph_and_eager_modes(assert_no_eager_garbage=True). It'd be nice to expand this coverage over time. Includes a small test_util simplification to support this (TFP tests don't like reset_default_graph for some reason). Testing for cycles in the TFP repo will need to wait on the Normal changes from the TF repo syncing. PiperOrigin-RevId: 211520394 --- tensorflow/python/framework/test_util.py | 19 ++++++++++--------- .../kernel_tests/distributions/normal_test.py | 4 ++-- .../python/ops/distributions/distribution.py | 18 ++++++++++++++++++ 3 files changed, 30 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index b5388ad0b2..3b63e49a84 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -535,15 +535,16 @@ def assert_no_new_tensors(f): tensors_before = set( id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj)) - if context.executing_eagerly(): - f(self, **kwargs) - ops.reset_default_graph() - else: - # Run the test in a new graph so that collections get cleared when it's - # done, but inherit the graph key so optimizers behave. - outside_graph_key = ops.get_default_graph()._graph_key - with ops.Graph().as_default(): - ops.get_default_graph()._graph_key = outside_graph_key + outside_executed_eagerly = context.executing_eagerly() + # Run the test in a new graph so that collections get cleared when it's + # done, but inherit the graph key so optimizers behave. + outside_graph_key = ops.get_default_graph()._graph_key + with ops.Graph().as_default(): + ops.get_default_graph()._graph_key = outside_graph_key + if outside_executed_eagerly: + with context.eager_mode(): + f(self, **kwargs) + else: f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py index 7ff48c0c10..5dcd6f6df4 100644 --- a/tensorflow/python/kernel_tests/distributions/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -91,7 +91,7 @@ class NormalTest(test.TestCase): self._testParamStaticShapes( tensor_shape.TensorShape(sample_shape), sample_shape) - @test_util.run_in_graph_and_eager_modes + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNormalWithSoftplusScale(self): with self.test_session(): mu = array_ops.zeros((10, 3)) @@ -329,7 +329,7 @@ class NormalTest(test.TestCase): self.assertAllEqual(normal.batch_shape, entropy.get_shape()) self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) - @test_util.run_in_graph_and_eager_modes + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNormalMeanAndMode(self): with self.test_session(): # Mu will be broadcast to [7, 7, 7]. diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py index ddf9442cd2..578e7b7dd2 100644 --- a/tensorflow/python/ops/distributions/distribution.py +++ b/tensorflow/python/ops/distributions/distribution.py @@ -446,6 +446,24 @@ class Distribution(_BaseDistribution): self._graph_parents = graph_parents self._name = name + @property + def _parameters(self): + return self._parameter_dict + + @_parameters.setter + def _parameters(self, value): + """Intercept assignments to self._parameters to avoid reference cycles. + + Parameters are often created using locals(), so we need to clean out any + references to `self` before assigning it to an attribute. + + Args: + value: A dictionary of parameters to assign to the `_parameters` property. + """ + if "self" in value: + del value["self"] + self._parameter_dict = value + @classmethod def param_shapes(cls, sample_shape, name="DistributionParamShapes"): """Shapes of parameters given the desired shape of a call to `sample()`. -- GitLab From 5bb543dbac388e794133975c4108daa1ccbc55ca Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Tue, 4 Sep 2018 14:08:12 -0700 Subject: [PATCH 073/540] [XLA] Add a test case for propagating the result layout of a non-elementwise HLO instruction to its operands. PiperOrigin-RevId: 211521410 --- .../xla/service/layout_assignment_test.cc | 76 +++++++++++++------ 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 021fe630ff..69c7e42601 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -874,18 +874,18 @@ TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) { )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto slice = FindInstruction(module.get(), "slice0"); - EXPECT_EQ(slice->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(slice->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, op::Add(op::Parameter(), + op::Slice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy))))); } TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { @@ -902,18 +902,20 @@ TEST_F(LayoutAssignmentTest, CopyDSliceOperandToAvoidImplicitLayoutChange) { )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto dslice = FindInstruction(module.get(), "dslice0"); - EXPECT_EQ(dslice->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(dslice->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::DynamicSlice(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); } TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { @@ -931,18 +933,20 @@ TEST_F(LayoutAssignmentTest, CopyConcatOperandToAvoidImplicitLayoutChange) { )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); - - auto copy = FindInstruction(module.get(), "copy.1"); - auto concat = FindInstruction(module.get(), "concat0"); - EXPECT_EQ(concat->operand(0), copy); - EXPECT_TRUE( - LayoutUtil::Equal(concat->shape().layout(), copy->shape().layout())); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {3, 5}, {1, 0}); + EXPECT_THAT(root, + op::Add(op::Parameter(), + op::Concatenate(AllOf(op::Copy(op::Parameter(1)), + op::ShapeWithLayout(shape_copy)), + op::Parameter(2)))); } TEST_F(LayoutAssignmentTest, @@ -960,15 +964,39 @@ TEST_F(LayoutAssignmentTest, )"; auto module = ParseHloString(module_str).ValueOrDie(); - module = + auto compiled_module = backend() .compiler() ->RunHloPasses(std::move(module), backend().default_stream_executor(), /*device_allocator=*/nullptr) .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Convolution(op::Parameter(0), op::Parameter(1))); +} + +TEST_F(LayoutAssignmentTest, PropagatingLayoutFromResultToOperand) { + const char* module_str = R"( + HloModule PropagatingLayoutFromResultToOperand + + ENTRY PropagatingLayoutFromResultToOperand { + par0 = f32[4,5]{1,0} parameter(0) + ROOT slice0 = f32[3,4]{0,1} slice(par0), slice={[1:4],[1:5]} + } + )"; - auto copy = FindInstruction(module.get(), "copy.1"); - EXPECT_EQ(copy, nullptr); + auto module = ParseHloString(module_str).ValueOrDie(); + auto compiled_module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + HloInstruction* root = + compiled_module->entry_computation()->root_instruction(); + Shape shape_copy = ShapeUtil::MakeShapeWithLayout(F32, {4, 5}, {0, 1}); + EXPECT_THAT(root, op::Slice(AllOf(op::Copy(op::Parameter(0)), + op::ShapeWithLayout(shape_copy)))); } } // namespace -- GitLab From ed643f5522774d8dcb98530cf241e94a86ae88c2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 14:18:49 -0700 Subject: [PATCH 074/540] Add unit test that shows how to use foldl with inputs that have different shapes. PiperOrigin-RevId: 211523104 --- tensorflow/python/kernel_tests/functional_ops_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 1e76ad7476..7739b13143 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -93,6 +93,15 @@ class FunctionalOpsTest(test.TestCase): initializer) self.assertAllEqual(1, self.evaluate(r)) + @test_util.run_in_graph_and_eager_modes + def testFoldl_MultiInputDifferentDimsSingleOutput(self): + elems = np.array([[1.0, 1.0, 1.0], [2.0, 3.0, 4.0]]) + other_elems = np.array([-1.0, 1.0]) + initializer = np.array([0.0, 0.0, 0.0]) + r = functional_ops.foldl(lambda a, x: a + x[0] * x[1], + (elems, other_elems), initializer) + self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r)) + def testFoldl_Scoped(self): with self.test_session() as sess: with variable_scope.variable_scope("root") as varscope: -- GitLab From 9bea7a8aa991b63f7349514a5a2dc0d04d261f8f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 14:28:37 -0700 Subject: [PATCH 075/540] Add support for Softmax of 3D tensors PiperOrigin-RevId: 211524810 --- .../contrib/lite/kernels/activations.cc | 36 +++++++++- .../contrib/lite/kernels/activations_test.cc | 70 +++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 9c891fe904..5cdd9fc94f 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -200,7 +200,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, input->type, output->type); const int num_dims = NumDimensions(input); - TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4); + TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4); if (input->type == kTfLiteUInt8) { TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); @@ -453,6 +453,19 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output, Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f); } +// Takes a 3D tensor and perform softmax along the last dimension. +void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int intermediate_size = input->dims->data[1]; + const int input_size = input->dims->data[2]; + optimized_ops::Softmax( + GetTensorData(input), + GetTensorShape({batch_size, intermediate_size, 1, input_size}), + params->beta, GetTensorData(output), + GetTensorShape({batch_size, intermediate_size, 1, input_size})); +} + void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params, OpData* data) { // TODO(ahentz): this is arguably a dirty trick. Since the implementation @@ -480,6 +493,19 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output, GetTensorShape({batch_size, 1, 1, input_size})); } +void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + const int batch_size = input->dims->data[0]; + const int intermediate_size = input->dims->data[1]; + const int input_size = input->dims->data[2]; + optimized_ops::Softmax( + GetTensorData(input), + GetTensorShape({batch_size, intermediate_size, 1, input_size}), + data->input_multiplier, data->input_left_shift, data->diff_min, + GetTensorData(output), + GetTensorShape({batch_size, intermediate_size, 1, input_size})); +} + // Takes a 4D tensor and perform softmax along the forth dimension. void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output, TfLiteSoftmaxParams* params) { @@ -515,6 +541,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax2DFloat(input, output, params); return kTfLiteOk; } + if (NumDimensions(input) == 3) { + Softmax3DFloat(input, output, params); + return kTfLiteOk; + } if (NumDimensions(input) == 4) { Softmax4DFloat(input, output, params); return kTfLiteOk; @@ -533,6 +563,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { Softmax2DQuantized(input, output, params, data); return kTfLiteOk; } + if (NumDimensions(input) == 3) { + Softmax3DQuantized(input, output, params, data); + return kTfLiteOk; + } if (NumDimensions(input) == 4) { Softmax4DQuantized(input, output, params, data); return kTfLiteOk; diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index e577e3a762..9fa47e190a 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -339,6 +339,76 @@ TEST(QuantizedActivationsOpTest, Softmax4D) { kQuantizedTolerance))); } +TEST(FloatActivationsOpTest, Softmax3D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 1, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax3D) { + QuantizedActivationsOpModel m( + 0.1, + /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2( + 0.1, + /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + TEST(FloatActivationsOpTest, Softmax1D) { FloatActivationsOpModel m(0.1, /*input=*/{TensorType_FLOAT32, {8}}); -- GitLab From ee24255e3dddae6c1d1cf44f6cf800883015fc8e Mon Sep 17 00:00:00 2001 From: Michael Case Date: Tue, 4 Sep 2018 15:04:21 -0700 Subject: [PATCH 076/540] Internal Change PiperOrigin-RevId: 211531374 --- tensorflow/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index b5e0a4e98b..661cba5ff0 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -433,6 +433,7 @@ package_group( "-//third_party/tensorflow/python/estimator", "//learning/meta_rank/...", "//tensorflow/...", + "//tensorflow_estimator/...", "//tensorflow_fold/llgtm/...", "//third_party/py/tensor2tensor/...", ], -- GitLab From 4fbc4e5b9833fb1936250d8a52aad57e7c7469e2 Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Tue, 4 Sep 2018 15:11:28 -0700 Subject: [PATCH 077/540] Automatically use single core for stateful RNN in Keras TPU. PiperOrigin-RevId: 211532963 --- .../contrib/tpu/python/tpu/keras_support.py | 132 ++++++++++++------ 1 file changed, 89 insertions(+), 43 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index ff88508d03..dd7f8b678f 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -170,11 +170,41 @@ class TPUDistributionStrategy(object): worker_re = re.compile('/job:([^/]+)') for device in metadata.devices: if 'TPU:0' in device.name: - self.worker_name = worker_re.search(device.name).group(1) + self._worker_name = worker_re.search(device.name).group(1) break + def _make_assignment_for_model(self, cpu_model): + """Makes a `TPUAssignment` for the passed in `cpu_model`.""" + num_cores = self._num_cores + if num_cores > 1 and cpu_model.stateful: + logging.warning( + 'Model replication does not currently support stateful models. ' + 'Degrading to a single core.') + num_cores = 1 + + return TPUAssignment( + worker_name=self._worker_name, num_cores=num_cores) + + +class TPUAssignment(object): + """This is object holding TPU resources assignment for the concrete model. + + `TPUDistributionStrategy` is responsible to create the instance of + `TPUAssignment`, so, it can dynamically adjust the `num_cores` to use based on + model and input batch sizes. + """ + + def __init__(self, worker_name, num_cores): + self._worker_name = worker_name + self._num_cores = num_cores + + @property + def worker_name(self): + return self._worker_name + @property def num_towers(self): + # TODO(xiejw): Support automatically assign num_cores based on inputs. return self._num_cores @@ -495,8 +525,8 @@ class TPUNumpyInfeedManager(TPUInfeedManager): infeed_dict[tensor] = value return infeed_dict - def __init__(self, distribution_strategy): - self._strategy = distribution_strategy + def __init__(self, tpu_assignment): + self._tpu_assignment = tpu_assignment def _split_tensors(self, inputs): """Split input data across shards. @@ -509,16 +539,16 @@ class TPUNumpyInfeedManager(TPUInfeedManager): Returns: List of lists containing the input to feed to each TPU shard. """ - if self._strategy.num_towers == 1: + if self._tpu_assignment.num_towers == 1: return [inputs] batch_size = inputs[0].shape[0] - assert batch_size % self._strategy.num_towers == 0, ( - 'batch_size must be divisible by strategy.num_towers (%s vs %s)' % - (batch_size, self._strategy.num_towers)) - shard_size = batch_size // self._strategy.num_towers + assert batch_size % self._tpu_assignment.num_towers == 0, ( + 'batch_size must be divisible by the number of TPU cores in use (%s ' + 'vs %s)' % (batch_size, self._tpu_assignment.num_towers)) + shard_size = batch_size // self._tpu_assignment.num_towers input_list = [] - for index in range(self._strategy.num_towers): + for index in range(self._tpu_assignment.num_towers): shard_inputs = [ x[index * shard_size:(index + 1) * shard_size] for x in inputs ] @@ -533,8 +563,9 @@ class TPUNumpyInfeedManager(TPUInfeedManager): infeed_op = [] shard_infeed_tensors = [] - for shard_id in range(self._strategy.num_towers): - with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + for shard_id in range(self._tpu_assignment.num_towers): + with ops.device( + '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): infeed_tensors = [] with ops.device('/device:TPU:%d' % shard_id): for spec in input_specs: @@ -573,30 +604,31 @@ class TPUDatasetInfeedManager(TPUInfeedManager): # TODO(saeta): Verify tpu_model_op is as expected! return {} - def __init__(self, dataset, distribution_strategy, tpu_session): + # pylint: disable=redefined-outer-name + def __init__(self, dataset, tpu_assignment, tpu_session): """Constructs a TPUDatasetInfeedManager. Must be called within a `KerasTPUModel.tpu_session` context! Args: dataset: A `tf.data.Dataset` to infeed. - distribution_strategy: The `TPUDistributionStrategy` used to configure the + tpu_assignment: The `TPUAssignment` used to configure the Keras TPU model. tpu_session: The `tf.Session` object used for running the TPU model. """ self._verify_dataset_shape(dataset) self._dataset = dataset - self._strategy = distribution_strategy + self._tpu_assignment = tpu_assignment dummy_x_shape = dataset.output_shapes[0].as_list() - dummy_x_shape[0] *= distribution_strategy.num_towers + dummy_x_shape[0] *= tpu_assignment.num_towers dummy_y_shape = dataset.output_shapes[1].as_list() - dummy_y_shape[0] *= distribution_strategy.num_towers + dummy_y_shape[0] *= tpu_assignment.num_towers self._iterator = dataset.make_initializable_iterator() tpu_session.run(self._iterator.initializer) self._get_next_ops = [] ctrl_deps = [] - for i in range(distribution_strategy.num_towers): + for i in range(tpu_assignment.num_towers): with ops.control_dependencies(ctrl_deps): # Ensure deterministic # TODO(saeta): Ensure correct placement! get_next_op = self._iterator.get_next() @@ -676,10 +708,11 @@ class TPUDatasetInfeedManager(TPUInfeedManager): def build_infeed_from_input_specs(self, input_specs, execution_mode): shard_infeed_tensors = self._get_next_ops - assert len(shard_infeed_tensors) == self._strategy.num_towers + assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers infeed_ops = [] - for shard_id in range(self._strategy.num_towers): - with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + for shard_id in range(self._tpu_assignment.num_towers): + with ops.device( + '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): infeed_ops.append( tpu_ops.infeed_enqueue_tuple( shard_infeed_tensors[shard_id], @@ -702,10 +735,10 @@ class TPUFunction(object): instead of being injected as `feed_dict` items or fetches. """ - def __init__(self, model, execution_mode, strategy): + def __init__(self, model, execution_mode, tpu_assignment): self.model = model self.execution_mode = execution_mode - self._strategy = strategy + self._tpu_assignment = tpu_assignment self._compilation_cache = {} self._cloned_model = None @@ -757,7 +790,8 @@ class TPUFunction(object): # Clone our CPU model, running within the TPU device context. with TPURewriteContext(tpu_input_map): with variable_scope.variable_scope('tpu_model_%s' % id(self.model)): - with keras_tpu_variables.replicated_scope(self._strategy.num_towers): + with keras_tpu_variables.replicated_scope( + self._tpu_assignment.num_towers): self._cloned_model = models.clone_model(self.model) # Create a copy of the optimizer for this graph. @@ -827,7 +861,7 @@ class TPUFunction(object): # `execute op` replicates `_model_fn` `num_replicas` times, with each shard # running on a different logical core. compile_op, execute_op = tpu.split_compile_and_replicate( - _model_fn, inputs=[[]] * self._strategy.num_towers) + _model_fn, inputs=[[]] * self._tpu_assignment.num_towers) # Generate CPU side operations to enqueue features/labels and dequeue # outputs from the model call. @@ -835,8 +869,9 @@ class TPUFunction(object): input_specs, self.execution_mode) # Build output ops. outfeed_op = [] - for shard_id in range(self._strategy.num_towers): - with ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + for shard_id in range(self._tpu_assignment.num_towers): + with ops.device( + '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): outfeed_op.extend( tpu_ops.outfeed_dequeue_tuple( dtypes=[spec.dtype for spec in self._outfeed_spec], @@ -886,7 +921,7 @@ class TPUFunction(object): for x, mgr in self.model._numpy_to_infeed_manager_list: if inputs[0] is x: return mgr - return TPUNumpyInfeedManager(self.model._strategy) + return TPUNumpyInfeedManager(self.model._tpu_assignment) def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager): """Looks up the corresponding `TPUModelOp` for a given `input_specs`. @@ -958,7 +993,7 @@ class TPUFunction(object): outputs = [[]] * len(self._outfeed_spec) outputs_per_replica = len(self._outfeed_spec) - for i in range(self._strategy.num_towers): + for i in range(self._tpu_assignment.num_towers): output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) * outputs_per_replica] for j in range(outputs_per_replica): @@ -967,7 +1002,7 @@ class TPUFunction(object): return [np.concatenate(group) for group in outputs] else: return outfeed_outputs[:len(outfeed_outputs) // - self._strategy.num_towers] + self._tpu_assignment.num_towers] def __call__(self, inputs): """__call__ executes the function on the computational hardware. @@ -1119,11 +1154,11 @@ class KerasTPUModel(models.Model): self.predict_function = None self.test_function = None self.train_function = None - self._strategy = strategy - cluster_resolver = self._strategy._tpu_cluster_resolver + cluster_resolver = strategy._tpu_cluster_resolver self._tpu_name_or_address = cluster_resolver.get_master() self._cpu_model = cpu_model + self._tpu_assignment = strategy._make_assignment_for_model(cpu_model) self._tpu_model = None self._tpu_weights_initialized = False @@ -1146,7 +1181,7 @@ class KerasTPUModel(models.Model): return { 'cpu_model': self._cpu_model, 'tpu_name_or_address': self._tpu_name_or_address, - 'strategy': self._strategy, + 'tpu_assignment': self._tpu_assignment, } def compile(self, @@ -1207,7 +1242,7 @@ class KerasTPUModel(models.Model): '/keras') if callable(x): with self.tpu_session() as sess,\ - ops.device('/job:%s/device:CPU:0' % self._strategy.worker_name): + ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): dataset = x() if steps_per_epoch is None: raise ValueError('When using tf.data as input to a model, you ' @@ -1215,7 +1250,8 @@ class KerasTPUModel(models.Model): if y is not None: raise ValueError('When using tf.data as input to a model, y must be ' 'None') - infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + sess) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. x = infeed_manager.dummy_x @@ -1236,7 +1272,8 @@ class KerasTPUModel(models.Model): if validation_steps is None: raise ValueError('When using tf.data as validation for a model, you ' 'should specify the validation_steps argument.') - infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + sess) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. val_x = infeed_manager.dummy_x @@ -1313,7 +1350,8 @@ class KerasTPUModel(models.Model): if y is not None: raise ValueError('When using tf.data as input to a model, y must be ' 'None') - infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess) + infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, + sess) # Use dummy numpy inputs for the rest of Keras' shape checking. We # intercept them when building the model. x = infeed_manager.dummy_x @@ -1740,20 +1778,24 @@ class KerasTPUModel(models.Model): def _make_train_function(self): if not self.train_function: self.train_function = TPUFunction( - self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy) + self, + model_fn_lib.ModeKeys.TRAIN, + tpu_assignment=self._tpu_assignment) return self.train_function def _make_test_function(self): if not self.test_function: self.test_function = TPUFunction( - self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy) + self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) return self.test_function def _make_predict_function(self): if not self.predict_function: self.predict_function = TPUFunction( - self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy) + self, + model_fn_lib.ModeKeys.PREDICT, + tpu_assignment=self._tpu_assignment) return self.predict_function def _initialize_weights(self, cloned_model): @@ -1825,6 +1867,7 @@ class KerasTPUModel(models.Model): self._session.close() +# pylint: disable=bad-continuation def _validate_shapes(model): """Validate that all layers in `model` have constant shape.""" for layer in model.layers: @@ -1852,10 +1895,13 @@ Layer: %(layer)s Input shape: %(input_shape)s Output shape: %(output_shape)s """ % { - 'layer': layer, - 'input_shape': layer.input_shape, - 'output_shape': layer.output_shape - }) + 'layer': layer, + 'input_shape': layer.input_shape, + 'output_shape': layer.output_shape + }) + + +# pylint: enable=bad-continuation @experimental -- GitLab From 5b576291e3ba981249d2666d9061b92725d703c2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 15:18:15 -0700 Subject: [PATCH 078/540] Update ops-related pbtxt files. PiperOrigin-RevId: 211534283 --- .../core/ops/compat/ops_history.v1.pbtxt | 43 +++++++++++++++++++ tensorflow/core/ops/ops.pbtxt | 7 +++ 2 files changed, 50 insertions(+) diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index cb0cb46752..9836f784ab 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -29380,6 +29380,49 @@ op { minimum: 1 } } +op { + name: "MapDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} op { name: "MapDefun" input_arg { diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 4419f93d0c..28b25fdeae 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -14542,6 +14542,13 @@ op { has_minimum: true minimum: 1 } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } } op { name: "MapDefun" -- GitLab From 5cb997a35383bc2832be5a415d72aa950374ebfa Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 4 Sep 2018 15:28:17 -0700 Subject: [PATCH 079/540] Sort namedtuple fields PiperOrigin-RevId: 211535930 --- tensorflow/tools/docs/parser.py | 26 +++++++++++++++- tensorflow/tools/docs/parser_test.py | 46 +++++++++++++++++++++++++++- tensorflow/tools/docs/pretty_docs.py | 2 +- 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 997afc6ac7..549056c6c4 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -947,6 +947,7 @@ class _ClassPageInfo(object): self._aliases = None self._doc = None self._guides = None + self._namedtuplefields = None self._bases = None self._properties = [] @@ -1029,6 +1030,17 @@ class _ClassPageInfo(object): assert self.guides is None self._guides = guides + @property + def namedtuplefields(self): + return self._namedtuplefields + + def set_namedtuplefields(self, py_class): + if issubclass(py_class, tuple): + if all( + hasattr(py_class, attr) + for attr in ('_asdict', '_fields', '_make', '_replace')): + self._namedtuplefields = py_class._fields + @property def bases(self): """Returns a list of `_LinkInfo` objects pointing to the class' parents.""" @@ -1066,7 +1078,15 @@ class _ClassPageInfo(object): @property def properties(self): """Returns a list of `_PropertyInfo` describing the class' properties.""" - return self._properties + props_dict = {prop.short_name: prop for prop in self._properties} + props = [] + if self.namedtuplefields: + for field in self.namedtuplefields: + props.append(props_dict.pop(field)) + + props.extend(sorted(props_dict.values())) + + return props def _add_property(self, short_name, full_name, obj, doc): """Adds a `_PropertyInfo` entry to the `properties` list. @@ -1077,6 +1097,9 @@ class _ClassPageInfo(object): obj: The property object itself doc: The property's parsed docstring, a `_DocstringInfo`. """ + # Hide useless namedtuple docs-trings + if re.match('Alias for field number [0-9]+', doc.docstring): + doc = doc._replace(docstring='', brief='') property_info = _PropertyInfo(short_name, full_name, obj, doc) self._properties.append(property_info) @@ -1156,6 +1179,7 @@ class _ClassPageInfo(object): py_class: The class object being documented parser_config: An instance of ParserConfig. """ + self.set_namedtuplefields(py_class) doc_path = documentation_path(self.full_name) relative_path = os.path.relpath( path='.', start=os.path.dirname(doc_path) or '.') diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 9f6b185e81..71e96afa10 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import os import sys @@ -190,6 +191,50 @@ class ParserTest(googletest.TestCase): # Make sure this file is contained as the definition location. self.assertEqual(os.path.relpath(__file__, '/'), page_info.defined_in.path) + def test_namedtuple_field_order(self): + namedtupleclass = collections.namedtuple('namedtupleclass', + {'z', 'y', 'x', 'w', 'v', 'u'}) + + index = { + 'namedtupleclass': namedtupleclass, + 'namedtupleclass.u': namedtupleclass.u, + 'namedtupleclass.v': namedtupleclass.v, + 'namedtupleclass.w': namedtupleclass.w, + 'namedtupleclass.x': namedtupleclass.x, + 'namedtupleclass.y': namedtupleclass.y, + 'namedtupleclass.z': namedtupleclass.z, + } + + visitor = DummyVisitor(index=index, duplicate_of={}) + + reference_resolver = parser.ReferenceResolver.from_visitor( + visitor=visitor, doc_index={}, py_module_names=['tf']) + + tree = {'namedtupleclass': {'u', 'v', 'w', 'x', 'y', 'z'}} + parser_config = parser.ParserConfig( + reference_resolver=reference_resolver, + duplicates={}, + duplicate_of={}, + tree=tree, + index=index, + reverse_index={}, + guide_index={}, + base_dir='/') + + page_info = parser.docs_for_object( + full_name='namedtupleclass', + py_object=namedtupleclass, + parser_config=parser_config) + + # Each namedtiple field has a docstring of the form: + # 'Alias for field number ##'. These props are returned sorted. + + def sort_key(prop_info): + return int(prop_info.obj.__doc__.split(' ')[-1]) + + self.assertSequenceEqual(page_info.properties, + sorted(page_info.properties, key=sort_key)) + def test_docs_for_class_should_skip(self): class Parent(object): @@ -736,6 +781,5 @@ class TestGenerateSignature(googletest.TestCase): sig = parser._generate_signature(example_fun, reverse_index={}) self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"]) - if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/docs/pretty_docs.py b/tensorflow/tools/docs/pretty_docs.py index aecf753a58..448f246e0e 100644 --- a/tensorflow/tools/docs/pretty_docs.py +++ b/tensorflow/tools/docs/pretty_docs.py @@ -136,7 +136,7 @@ def _build_class_page(page_info): if page_info.properties: parts.append('## Properties\n\n') - for prop_info in sorted(page_info.properties): + for prop_info in page_info.properties: h3 = '

{short_name}

\n\n' parts.append(h3.format(short_name=prop_info.short_name)) -- GitLab From d72b4c0d4972c7da2a226c9692dbbd450cac4959 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Camargo?= Date: Tue, 4 Sep 2018 19:55:24 -0300 Subject: [PATCH 080/540] LSTMCell base article at rnn_cell_impl.py resubmitting to master branch For discussion see https://github.com/tensorflow/tensorflow/pull/22035 --- tensorflow/python/ops/rnn_cell_impl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index fa13568596..e8698c6359 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -783,10 +783,10 @@ class LSTMCell(LayerRNNCell): The default non-peephole implementation is based on: - http://www.bioinf.jku.at/publications/older/2604.pdf + https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf - S. Hochreiter and J. Schmidhuber. - "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + Felix Gers, Jürgen Schmidhuber, and Fred Cummins. + "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999. The peephole implementation is based on: -- GitLab From 69753ba5dbe5950639efc1b5e065901651cd8973 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Tue, 4 Sep 2018 15:57:55 -0700 Subject: [PATCH 081/540] Create a way to serialize Interpreter data to a flatbuffer. PiperOrigin-RevId: 211540844 --- .../contrib/lite/experimental/writer/BUILD | 64 +++ .../lite/experimental/writer/enum_mapping.h | 116 ++++++ .../writer/option_writer_generator.cc | 370 ++++++++++++++++++ .../lite/experimental/writer/writer.cc | 41 ++ .../lite/experimental/writer/writer_lib.cc | 281 +++++++++++++ .../lite/experimental/writer/writer_lib.h | 126 ++++++ .../experimental/writer/writer_lib_test.cc | 62 +++ tensorflow/contrib/lite/op_resolver.cc | 2 + tensorflow/contrib/lite/schema/BUILD | 14 + 9 files changed, 1076 insertions(+) create mode 100644 tensorflow/contrib/lite/experimental/writer/BUILD create mode 100644 tensorflow/contrib/lite/experimental/writer/enum_mapping.h create mode 100644 tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc create mode 100644 tensorflow/contrib/lite/experimental/writer/writer.cc create mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib.cc create mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib.h create mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD new file mode 100644 index 0000000000..d43964208b --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/BUILD @@ -0,0 +1,64 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +cc_binary( + name = "option_writer_generator", + srcs = ["option_writer_generator.cc"], + deps = [ + "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + "@flatbuffers", + ], +) + +cc_library( + name = "writer_lib", + srcs = [ + "enum_mapping.h", + "writer_lib.cc", + ], + hdrs = [ + "writer_lib.h", + ], + textual_hdrs = ["option_writer_generated.h"], + deps = [ + ":option_writer_gen", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", + ], +) + +cc_binary( + name = "writer", + srcs = ["writer.cc"], + deps = [ + ":writer_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_test( + name = "writer_lib_test", + size = "small", + srcs = ["writer_lib_test.cc"], + deps = [ + ":writer_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/testing:util", + "//testing/base/public:gunit", + ], +) + +genrule( + name = "option_writer_gen", + outs = ["option_writer_generated.h"], + cmd = "$(location :option_writer_generator) $(@)", + tools = [":option_writer_generator"], +) diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h new file mode 100644 index 0000000000..8bc464fd71 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" + +// TODO(aselle): Ideally extract this from the schema. + +namespace tflite { + +inline ActivationFunctionType TfLiteActivationToSchemaActivation( + TfLiteFusedActivation act) { + switch (act) { + case kTfLiteActNone: + return ActivationFunctionType_NONE; + case kTfLiteActRelu: + return ActivationFunctionType_RELU; + case kTfLiteActRelu1: + return ActivationFunctionType_RELU_N1_TO_1; + case kTfLiteActRelu6: + return ActivationFunctionType_RELU6; + case kTfLiteActTanh: + return ActivationFunctionType_TANH; + case kTfLiteActSignBit: + return ActivationFunctionType_SIGN_BIT; + case kTfLiteActSigmoid: + return ActivationFunctionType_NONE; // TODO(aselle): Add to schema + } + return ActivationFunctionType_NONE; +} + +inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) { + switch (padding) { + case kTfLitePaddingUnknown: + return Padding_SAME; // TODO(aselle): Consider an error. + case kTfLitePaddingSame: + return Padding_SAME; + case kTfLitePaddingValid: + return Padding_VALID; + } + return Padding_SAME; // TODO(aselle): Consider an error. +} + +inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { + switch (type) { + // case kTfLiteNoType: return TensorType_NONE; + case kTfLiteNoType: + return TensorType_FLOAT32; // TODO(aselle): Consider an error. + case kTfLiteFloat32: + return TensorType_FLOAT32; + case kTfLiteInt32: + return TensorType_INT32; + case kTfLiteUInt8: + return TensorType_UINT8; + case kTfLiteInt64: + return TensorType_INT64; + case kTfLiteString: + return TensorType_STRING; + case kTfLiteBool: + return TensorType_BOOL; + case kTfLiteInt16: + return TensorType_INT16; + case kTfLiteComplex64: + return TensorType_COMPLEX64; + } + // TODO(aselle): consider an error +} + +inline FullyConnectedOptionsWeightsFormat +FullyConnectedOptionsWeightsFormatToSchema( + TfLiteFullyConnectedWeightsFormat format) { + switch (format) { + case kTfLiteFullyConnectedWeightsFormatDefault: + return FullyConnectedOptionsWeightsFormat_DEFAULT; + case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8: + return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; + } +} + +inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) { + switch (type) { + case kTfLiteLSTMFullKernel: + return LSTMKernelType_FULL; + case kTfLiteLSTMBasicKernel: + return LSTMKernelType_BASIC; + } +} + +inline LSHProjectionType LSHProjectionTypeToSchema( + TfLiteLSHProjectionType type) { + switch (type) { + case kTfLiteLshProjectionUnknown: + return LSHProjectionType_UNKNOWN; + case kTfLiteLshProjectionSparse: + return LSHProjectionType_SPARSE; + case kTfLiteLshProjectionDense: + return LSHProjectionType_DENSE; + } +} + +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc new file mode 100644 index 0000000000..e6d5a776b3 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc @@ -0,0 +1,370 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include "flatbuffers/minireflect.h" // flatbuffers +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" + +namespace tflite { +namespace { +// This is generated by grepping +// cat third_party/tensorflow/contrib/lite/builtin_op_data.h +//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}" +static const char* param_structs[] = {"TfLiteConvParams", + "TfLitePoolParams", + "TfLiteDepthwiseConvParams", + "TfLiteSVDFParams", + "TfLiteRNNParams", + "TfLiteSequenceRNNParams", + "TfLiteFullyConnectedParams", + "TfLiteLSHProjectionParams", + "TfLiteSoftmaxParams", + "TfLiteConcatenationParams", + "TfLiteAddParams", + "TfLiteSpaceToBatchNDParams", + "TfLiteBatchToSpaceNDParams", + "TfLiteMulParams", + "TfLiteSubParams", + "TfLiteDivParams", + "TfLiteL2NormParams", + "TfLiteLocalResponseNormParams", + "TfLiteLSTMParams", + "TfLiteResizeBilinearParams", + "TfLitePadParams", + "TfLitePadV2Params", + "TfLiteReshapeParams", + "TfLiteSkipGramParams", + "TfLiteSpaceToDepthParams", + "TfLiteCastParams", + "TfLiteEmbeddingLookupSparseParams", + "TfLiteGatherParams", + "TfLiteTransposeParams", + "TfLiteReducerParams", + "TfLiteSplitParams", + "TfLiteSqueezeParams", + "TfLiteStridedSliceParams", + "TfLiteArgMaxParams", + "TfLiteArgMinParams", + "TfLiteTransposeConvParams", + "TfLiteSparseToDenseParams", + "TfLiteShapeParams", + "TfLiteFakeQuantParams", + "TfLitePackParams", + "TfLiteOneHotParams", + nullptr}; +} // namespace + +// Get rid of all underscores and make everything lower case to make name +// matching work for stuff like 3D vs 3d or RNN vs Rnn. +std::string ToCollapsed(const std::string& in) { + const char* s = in.c_str(); + bool first = true; + std::string out; + while (*s != '\0') { + if (*s == '_') { + first = true; + } else if (first) { + out.push_back(tolower(*s)); + first = false; + } else { + out.push_back(tolower(*s)); + } + s++; + } + return out; +} + +// A collection of information about builtin ops. +class OpOptionData { + public: + OpOptionData() { + BuildOpList(); + BuildOptionToTypeFunctionMap(); + BuildOpToOptionMap(); + } + + // A list of builtin operations + const std::vector& ops() const { return ops_; } + // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions') + const std::unordered_map& op_to_option() { + return op_to_option_; + } + // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions' + const std::unordered_map& option_to_struct() { + return option_to_struct_; + } + // Maps from option to a flatbuffer type function that describes that option. + const std::unordered_map& + option_to_type_function() { + return option_to_type_function_; + } + + private: + void BuildOpList() { + for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr; + ++curr) { + if (strlen(*curr) != 0) ops_.push_back(*curr); + } + } + + void BuildOptionToTypeFunctionMap() { + auto d = tflite::BuiltinOptionsTypeTable(); + for (int i = 0; i < d->num_elems; i++) { + flatbuffers::TypeCode code = d->type_codes[i]; + if (code.sequence_ref != -1) { + option_to_type_function_.insert( + std::make_pair(d->names[i], d->type_refs[code.sequence_ref])); + } + } + } + + void BuildOpToOptionMap() { + // Manually specified mappings between ops and options + op_to_option_["REDUCE_MAX"] = "ReducerOptions"; + op_to_option_["REDUCE_MIN"] = "ReducerOptions"; + op_to_option_["REDUCE_ANY"] = "ReducerOptions"; + op_to_option_["UNPACK"] = ""; + op_to_option_["SUM"] = "ReducerOptions"; + op_to_option_["REDUCE_MAX"] = "ReducerOptions"; + op_to_option_["REDUCE_PROD"] = "ReducerOptions"; + op_to_option_["MEAN"] = "ReducerOptions"; + op_to_option_["L2_POOL_2D"] = "Pool2DOptions"; + op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions"; + op_to_option_["MAX_POOL_2D"] = "Pool2DOptions"; + op_to_option_["L2_NORMALIZATION"] = "L2NormOptions"; + op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; + op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; + // Manually specified mappings between ops and options (none) + op_to_option_["EMBEDDING_LOOKUP"] = + ""; // TODO(aselle): maybe something else. + op_to_option_["FLOOR"] = ""; + op_to_option_["HASHTABLE_LOOKUP"] = + ""; // TODO(aselle): maybe something else. + op_to_option_["LOGISTIC"] = ""; + op_to_option_["RELU"] = ""; + op_to_option_["RELU_N1_TO_1"] = ""; + op_to_option_["RELU6"] = ""; + op_to_option_["TANH"] = ""; + op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. + op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. + op_to_option_["PRELU"] = ""; + op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions + op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions + op_to_option_["SIN"] = ""; + op_to_option_["LOG"] = ""; + op_to_option_["SQRT"] = ""; + op_to_option_["RSQRT"] = ""; + + // TODO(aselle): These are undesirable hacks. Consider changing C structs + option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; + option_to_struct_["Conv2DOptions"] = "TfLiteConvParams"; + option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams"; + option_to_struct_["LocalResponseNormalizationOptions"] = + "TfLiteLocalResponseNormParams"; + // Now for every op, try to find an option. + bool fatal = false; + for (auto op_name : ops_) { + bool found_option = false; + auto d = tflite::BuiltinOptionsTypeTable(); + std::string collapsed_option_name_guess = + ToCollapsed(op_name) + "options"; + // O(n^2) but not that big of n. + for (int i = 0; i < d->num_elems; i++) { + std::string option_name = d->names[i]; + std::string collapsed_option_name = ToCollapsed(option_name); + if (collapsed_option_name_guess == collapsed_option_name) { + op_to_option_.insert(std::make_pair(op_name, option_name)); + found_option = true; + break; + } + } + auto it = op_to_option_.find(op_name); + if (it == op_to_option_.end()) { + std::cerr << "Didn't find option for " << op_name << std::endl; + fatal = true; + } else if (!it->second.empty()) { + std::string option_name = it->second; + + if (option_to_struct_.find(option_name) == option_to_struct_.end()) { + bool param_struct_found = false; + std::string params_guess = std::string("TfLite") + option_name; + size_t start = params_guess.find("Options"); + size_t len = strlen("Options"); + params_guess.replace(start, len, "Params"); + for (auto* param = param_structs; *param != nullptr; param++) { + if (*param == params_guess) { + param_struct_found = true; + break; + } + } + if (!param_struct_found) { + std::cerr << "Failed to get param struct for option " << option_name + << std::endl; + fatal = true; + } else { + option_to_struct_.insert(std::make_pair(option_name, params_guess)); + } + } + } + } + } + + private: + std::vector ops_; + std::unordered_map op_to_option_; + std::unordered_map option_to_struct_; + std::unordered_map + option_to_type_function_; +}; + +void GenerateImportForOp(FILE* fp, const std::string& op_name, + const std::string& option_name, + const std::string& option_type, + const flatbuffers::TypeTable* options, + const std::string& struct_name) { + // Skip tricky ones for now + if (struct_name == "TfLiteResizeBilinearParams") return; + if (struct_name == "TfLiteSqueezeParams") return; + if (struct_name == "TfLiteEmbeddingLookupSparseParams") return; + if (struct_name == "TfLiteReshapeParams") return; + + fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str()); + fprintf(fp, + " const auto* params = reinterpret_cast(builtin_op_data);\n", + struct_name.c_str()); + + for (size_t i = 0; i < options->num_elems; i++) { + std::string elem_name = options->names[i]; + // TODO(aselle): Irregular naming in builtins + if (elem_name == "fused_activation_function") + elem_name = "activation"; + else if (elem_name == "stride_w") + elem_name = "stride_width"; + else if (elem_name == "stride_h") + elem_name = "stride_height"; + else if (elem_name == "dilation_h_factor") + elem_name = "dilation_height_factor"; + else if (elem_name == "dilation_w_factor") + elem_name = "dilation_width_factor"; + else if (elem_name == "new_shape") + elem_name = "shape"; + + flatbuffers::TypeCode code = options->type_codes[i]; + auto contained_type = code.sequence_ref != -1 + ? options->type_refs[code.sequence_ref] + : nullptr; + std::string mapper = ""; + if (contained_type == TensorTypeTypeTable) { + mapper = "TfLiteTypeToSchemaType"; + } else if (contained_type == ActivationFunctionTypeTypeTable) { + mapper = "TfLiteActivationToSchemaActivation"; + } else if (contained_type == PaddingTypeTable) { + mapper = "TfLitePaddingToSchemaPadding"; + } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) { + mapper = "FullyConnectedOptionsWeightsFormatToSchema"; + } else if (contained_type == LSTMKernelTypeTypeTable) { + mapper = "LSTMKernelTypeToSchema"; + } else if (contained_type == LSHProjectionTypeTypeTable) { + mapper = "LSHProjectionTypeToSchema"; + } + + fprintf(fp, + " auto val%zu = " + "%s(params->%s);\n", + i, mapper.c_str(), elem_name.c_str()); + } + fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str()); + for (size_t i = 0; i < options->num_elems; i++) { + fprintf(fp, ", val%zu", i); + } + fprintf(fp, ").Union();\n"); + fprintf(fp, " return std::make_pair(%s, union_type);\n", + option_type.c_str()); + fprintf(fp, " }\n break;\n"); +} + +void GenerateImport(OpOptionData* option, FILE* fp) { + std::unordered_set ignores; + ignores.insert("CONCAT_EMBEDDINGS"); + ignores.insert("CALL"); + + // Allow any op that doesn't have an options struct to be blocked + // together + for (const auto& op_name : option->ops()) { + auto option_it = option->op_to_option().find(op_name); + if (!option_it->second.empty() && ignores.find(op_name) == ignores.end()) + continue; + fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); + } + fprintf(fp, + " return std::make_pair(BuiltinOptions_NONE, " + "flatbuffers::Offset());\n break;\n"); + + // Iterate over each ops + for (const auto& op_name : option->ops()) { + if (ignores.find(op_name) != ignores.end()) continue; + // Get to the option and struct names, continuing if not found. + auto option_it = option->op_to_option().find(op_name); + if (option_it->second.empty()) continue; + std::string option_name = option_it->second; + std::string option_type = "BuiltinOptions_" + option_name; + auto option_func_it = option->option_to_type_function().find(option_name); + if (option_func_it == option->option_to_type_function().end()) continue; + auto struct_name_it = option->option_to_struct().find(option_name); + if (struct_name_it == option->option_to_struct().end()) { + // If no C struct, then it better have no arguments. + auto type_info = option_func_it->second(); + if (type_info->num_elems != 0) { + // We have non-zero arguments in the schema, this means there + // should be a struct. + fprintf(stderr, + "Op %s uses option struct %s which has no builtin struct\n", + op_name.c_str(), option_name.c_str()); + exit(1); + } + fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); + fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());", + option_type.c_str(), option_name.c_str()); + } else { + // If C struct, then we need to assign all properties + auto struct_name = struct_name_it->second; + GenerateImportForOp(fp, op_name, option_name, option_type, + option_func_it->second(), struct_name); + } + } + // TODO(aselle): Handle unhandled cases more gracefully. + fprintf(fp, + "default: return std::make_pair(BuiltinOptions_NONE, " + "flatbuffers::Offset());\n break;\n"); +} + +} // namespace tflite + +int main(int argc, char* argv[]) { + tflite::OpOptionData option; + if (argc != 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + FILE* fp = fopen(argv[1], "w"); + tflite::GenerateImport(&option, fp); + fclose(fp); +} diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc new file mode 100644 index 0000000000..20ede214fb --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Just does a read/write loop of tflite file format using the interpreter as +// an intermediate. +// +// Usage: +// writer + +#include + +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +int main(int argc, char* argv[]) { + if (argc != 3) { + fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]); + return 1; + } + std::unique_ptr model = + tflite::FlatBufferModel::BuildFromFile(argv[1]); + std::unique_ptr interpreter; + tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver; + tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); + tflite::InterpreterWriter writer(interpreter.get()); + writer.Write(argv[2]); + + return 0; +} diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc new file mode 100644 index 0000000000..52b17faf82 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc @@ -0,0 +1,281 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { +template +using Offset = flatbuffers::Offset; +template +using Vector = flatbuffers::Vector; +using FlatBufferBuilder = flatbuffers::FlatBufferBuilder; + +std::pair> CreateBuiltinUnion( + FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) { + switch (op) { +#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h" + } + return std::make_pair(BuiltinOptions_NONE, Offset()); +} + +template +Offset> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb, + const T_INPUT& v) { + std::vector inputs(v.begin(), v.end()); + return fbb->template CreateVector(inputs); +} + +Offset>> InterpreterWriter::ExportOperators( + FlatBufferBuilder* fbb) { + std::vector> operators; + + std::vector operator_to_opcode; + // TODO(aselle): Augment this once we put execution plan in schema. + operator_to_opcode.resize(interpreter_->nodes_size(), -1); + for (int op_index : interpreter_->execution_plan()) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + const TfLiteRegistration* registration = &node_and_registration->second; + if (!registration->custom_name) { + operator_to_opcode[op_index] = + GetOpCodeForBuiltin(registration->builtin_code); + } else { + operator_to_opcode[op_index] = + GetOpCodeForCustom(registration->custom_name); + } + } + // second pass serialize operators + for (int op_index : interpreter_->execution_plan()) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + const TfLiteNode& node = node_and_registration->first; + const TfLiteRegistration& registration = node_and_registration->second; + Offset builtin_options; + BuiltinOptions builtin_options_type = BuiltinOptions_NONE; + // Custom data + // TODO(aselle): Custom options format is not known by default. Just assume + // for now. + auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS; + Offset> custom_options = 0; + + if (!registration.custom_name) { + // builtin + auto builtin_options_and_type = CreateBuiltinUnion( + fbb, static_cast(registration.builtin_code), + node.builtin_data); + builtin_options = builtin_options_and_type.second; + builtin_options_type = builtin_options_and_type.first; + } else { + auto custom_writer = custom_op_to_writer_.find(registration.custom_name); + if (custom_writer != custom_op_to_writer_.end() && + custom_writer->second) { + // delegate to custom writer if it exists + custom_writer->second(fbb, interpreter_, op_index, &custom_options, + &custom_options_format); + } else { + // use the custom data as fact + custom_options = fbb->CreateVector( + reinterpret_cast(node.custom_initial_data), + node.custom_initial_data_size); + } + } + + int opcode_index = operator_to_opcode[op_index]; + std::vector written_inputs = + RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs)); + std::vector written_outputs = + RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs)); + auto inputs = ExportVector(fbb, written_inputs); + auto outputs = ExportVector(fbb, written_outputs); + operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs, + builtin_options_type, builtin_options, + custom_options, custom_options_format)); + } + + return fbb->template CreateVector>(operators); +} + +Offset>> InterpreterWriter::ExportTensors( + FlatBufferBuilder* fbb) { + tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1); + + std::vector> tensors; + + // Make a map from tensor index to whether the tensor is a temporary. + std::vector tensor_is_temporary(interpreter_->tensors_size(), false); + for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) { + const auto* node_and_registration = + interpreter_->node_and_registration(op_index); + for (auto tensor_index : + TfLiteIntArrayView(node_and_registration->first.temporaries)) + tensor_is_temporary[tensor_index] = true; + } + + // Now we need to remap all used tensor indices + int curr_output_index = 0; + for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); + tensor_index++) { + if (!tensor_is_temporary[tensor_index]) { + tensor_to_written_tensor_[tensor_index] = curr_output_index++; + } + } + + for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); + ++tensor_index) { + // Skip temporaries. + if (tensor_is_temporary[tensor_index]) continue; + + if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) { + // We only need to convert non temporaries + if (tensor->allocation_type != kTfLiteArenaRw && + tensor->allocation_type != kTfLiteMmapRo && + tensor->allocation_type != kTfLiteArenaRwPersistent) + continue; + // Allocate a buffer index + int buffer_index = 0; // This is null + if (tensor->allocation_type == kTfLiteMmapRo) { + buffer_index = buffers_.size(); + buffers_.push_back(std::make_pair( + reinterpret_cast(tensor->data.raw), tensor->bytes)); + } + // Primitive type. + TensorType type = TfLiteTypeToSchemaType(tensor->type); + // Handle quantization + const Offset> null_array; + Offset> scale_array; + Offset> zero_point_array; + if (tensor->params.scale != 0.f) { + // We have quantization, make a single arugment array (multi channel + // quant needs updating here). + scale_array = fbb->CreateVector({tensor->params.scale}); + zero_point_array = + fbb->CreateVector({tensor->params.zero_point}); + } + Offset quantization_params = + CreateQuantizationParameters(*fbb, null_array, null_array, + scale_array, zero_point_array); + // Shape + TfLiteIntArrayView shape_view(tensor->dims); + std::vector shape = + std::vector(shape_view.begin(), shape_view.end()); + + tensors.push_back(CreateTensor(*fbb, ExportVector(fbb, shape), + type, buffer_index, + fbb->CreateString(tensor->name), + quantization_params, tensor->is_variable)); + } + } + return fbb->template CreateVector>(tensors); +} + +Offset>> InterpreterWriter::ExportBuffers( + FlatBufferBuilder* fbb) { + std::vector> buffer_vector; + for (auto buffer : buffers_) { + auto data_offset = fbb->CreateVector(buffer.first, buffer.second); + buffer_vector.push_back(CreateBuffer(*fbb, data_offset)); + } + return fbb->template CreateVector>(buffer_vector); +} + +Offset>> InterpreterWriter::CreateOpCodeTable( + FlatBufferBuilder* fbb) { + std::vector> codes; + for (auto it : opcodes_) { + const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str(); + codes.push_back(CreateOperatorCodeDirect( + *fbb, static_cast(it.builtin), custom_name)); + } + return fbb->template CreateVector>(codes); +} + +template +std::vector InterpreterWriter::RemapTensorIndicesToWritten( + const T& input) { + std::vector output; + output.reserve(input.size()); + for (int x : input) { + output.push_back(tensor_to_written_tensor_[x]); + } + return output; +} + +TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr* out, + size_t* size) { + if (!out || !size) return kTfLiteError; + FlatBufferBuilder builder(/*initial_size=*/10240); + + std::vector> subgraphs_as_vector; + { // subgraph specific stuff + auto tensors = ExportTensors(&builder); + std::vector written_inputs = + RemapTensorIndicesToWritten(interpreter_->inputs()); + std::vector written_outputs = + RemapTensorIndicesToWritten(interpreter_->outputs()); + auto inputs = ExportVector(&builder, written_inputs); + auto outputs = ExportVector(&builder, written_outputs); + + auto ops = ExportOperators(&builder); + subgraphs_as_vector.push_back( + CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0)); + } + Offset>> buffers = ExportBuffers(&builder); + + auto description = builder.CreateString("Exported from Interpreter."); + + auto op_codes = CreateOpCodeTable(&builder); + auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes, + builder.CreateVector(subgraphs_as_vector), + description, buffers); + ::tflite::FinishModelBuffer(builder, model); + const uint8_t* buffer = builder.GetBufferPointer(); + *size = builder.GetSize(); + (*out).reset(new uint8_t[*size]); + memcpy(out->get(), buffer, *size); + return kTfLiteOk; +} + +TfLiteStatus InterpreterWriter::Write(const std::string& filename) { + std::unique_ptr buffer; + size_t size; + TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size)); + + FILE* fp = fopen(filename.c_str(), "wb"); + if (!fp) return kTfLiteError; + + if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError; + if (fclose(fp)) return kTfLiteError; + + return kTfLiteOk; +} + +TfLiteStatus InterpreterWriter::RegisterCustomWriter( + const std::string& custom_name, CustomWriter custom_writer) { + if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) { + return kTfLiteError; + } + custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer)); + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h new file mode 100644 index 0000000000..a98108b496 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter. +// +// Usage: +// From command line: +// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer +// -- foo.tflite foo.out.tflite +// +// From C++ +// std::unique_ptr interpreter; +// // Build Interpreter however +// // ... +// InterpreterWriter(interpreter.get()).Write("output.tflite"); +#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +// Handles writing TensorFlow Lite running interpreter to a serialized TF lite +// file format. +class InterpreterWriter { + public: + typedef flatbuffers::Offset (*CustomWriter)( + flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter, + int node_index, + flatbuffers::Offset>* output_options, + CustomOptionsFormat* custom_options_format); + + // Construct an interpreter writer for the specified `interpreter`. Then, + // a uses .Write() or .GetBuffer(...) to extract the data. + explicit InterpreterWriter(Interpreter* interpreter) + : interpreter_(interpreter) { + buffers_.push_back(std::make_pair(nullptr, 0)); + } + + // Get a buffer and size of a serialized flatbuffer. + TfLiteStatus GetBuffer(std::unique_ptr* out, size_t* size); + // Write the serialized flatbuffer to the prescribed `filename`. + TfLiteStatus Write(const std::string& filename); + // Registers a custom writer for a custom op. The customization allows the + // caller to change the custom data. + TfLiteStatus RegisterCustomWriter(const std::string& custom_name, + CustomWriter custom_writer); + + private: + template + using Offset = flatbuffers::Offset; + template + Offset> ExportVector( + flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v); + Offset>> ExportTensors( + flatbuffers::FlatBufferBuilder* fbb); + Offset>> ExportOperators( + flatbuffers::FlatBufferBuilder* fbb); + Offset>> CreateOpCodeTable( + flatbuffers::FlatBufferBuilder* fbb); + Offset>> ExportBuffers( + flatbuffers::FlatBufferBuilder* fbb); + + template + std::vector RemapTensorIndicesToWritten(const T& input); + + int GetOpCodeForBuiltin(int builtin_op_index) { + // auto it = builtin_op_to_opcode_.find(builtin_op_index); + std::pair result = + builtin_op_to_opcode_.insert( + std::make_pair(builtin_op_index, opcodes_.size())); + if (result.second) { + opcodes_.push_back({builtin_op_index, ""}); + } + return result.first->second; + } + + int GetOpCodeForCustom(const std::string& custom_name) { + std::pair result = + custom_op_to_opcode_.insert( + std::make_pair(custom_name, opcodes_.size())); + if (result.second) { + opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name}); + } + return result.first->second; + } + + // The interpreter we are writing + Interpreter* interpreter_; + // Keep track of byte buffers + std::vector> buffers_; + // List of op codes and mappings from builtin or custom op to opcode + struct OpCode { + int builtin; + std::string custom; + }; + // For every tensor index in the interpreter, the index in the written. + // This is different due to temporary tensors not being written. + std::vector tensor_to_written_tensor_; + // List of used opcodes + std::vector opcodes_; + std::unordered_map builtin_op_to_opcode_; + std::unordered_map custom_op_to_opcode_; + std::unordered_map custom_op_to_writer_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc new file mode 100644 index 0000000000..49194a76c8 --- /dev/null +++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +// Make an interpreter that has no tensors and no nodes +// TODO(b/113731921): add more tests. +TEST(Writer, BasicTest) { + Interpreter interpreter; + interpreter.AddTensors(3); + float foo[] = {1, 2, 3}; + interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3}, + TfLiteQuantizationParams()); + interpreter.SetTensorParametersReadOnly( + 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(), + reinterpret_cast(foo), sizeof(foo)); + interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3}, + TfLiteQuantizationParams()); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({2}); + const char* initial_data = ""; + tflite::ops::builtin::BuiltinOpResolver resolver; + TfLiteAddParams* builtin_data = + reinterpret_cast(malloc(sizeof(TfLiteAddParams))); + builtin_data->activation = kTfLiteActNone; + const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1); + interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, + reinterpret_cast(builtin_data), reg); + + InterpreterWriter writer(&interpreter); + writer.Write("/tmp/test.tflite"); + std::unique_ptr model = + FlatBufferModel::BuildFromFile("/tmp/test.tflite"); + InterpreterBuilder builder(*model, resolver); + std::unique_ptr new_interpreter; + builder(&new_interpreter); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/op_resolver.cc index f6e435e982..a9885f7737 100644 --- a/tensorflow/contrib/lite/op_resolver.cc +++ b/tensorflow/contrib/lite/op_resolver.cc @@ -46,6 +46,8 @@ void MutableOpResolver::AddCustom(const char* name, TfLiteRegistration* registration, int min_version, int max_version) { for (int version = min_version; version <= max_version; ++version) { + // TODO(aselle): This should verify that the incoming registration + // has the name in the registration already and it matches!!! TfLiteRegistration new_registration = *registration; new_registration.builtin_code = BuiltinOperator_CUSTOM; new_registration.version = version; diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 28a7e50003..55bf2c48b9 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -56,6 +56,20 @@ flatbuffer_cc_library( srcs = ["schema.fbs"], ) +# Generic schema for inference on device (but with reflections makes bigger). +flatbuffer_cc_library( + name = "schema_fbs_with_reflection", + srcs = ["schema.fbs"], + flatc_args = [ + "--reflect-types", + "--reflect-names", + "--no-union-value-namespacing", + "--gen-object-api", + ], + gen_reflections = True, + out_prefix = "reflection/", +) + # Schema test to make sure we don't introduce backward incompatible changes # to schemas. cc_test( -- GitLab From 0065d3389a63a529469dc71e950c66da2ebdbc24 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Tue, 4 Sep 2018 16:01:54 -0700 Subject: [PATCH 082/540] Automated rollback of commit 69753ba5dbe5950639efc1b5e065901651cd8973 PiperOrigin-RevId: 211541639 --- .../contrib/lite/experimental/writer/BUILD | 64 --- .../lite/experimental/writer/enum_mapping.h | 116 ------ .../writer/option_writer_generator.cc | 370 ------------------ .../lite/experimental/writer/writer.cc | 41 -- .../lite/experimental/writer/writer_lib.cc | 281 ------------- .../lite/experimental/writer/writer_lib.h | 126 ------ .../experimental/writer/writer_lib_test.cc | 62 --- tensorflow/contrib/lite/op_resolver.cc | 2 - tensorflow/contrib/lite/schema/BUILD | 14 - 9 files changed, 1076 deletions(-) delete mode 100644 tensorflow/contrib/lite/experimental/writer/BUILD delete mode 100644 tensorflow/contrib/lite/experimental/writer/enum_mapping.h delete mode 100644 tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc delete mode 100644 tensorflow/contrib/lite/experimental/writer/writer.cc delete mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib.cc delete mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib.h delete mode 100644 tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD deleted file mode 100644 index d43964208b..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/BUILD +++ /dev/null @@ -1,64 +0,0 @@ -package(default_visibility = [ - "//visibility:public", -]) - -licenses(["notice"]) # Apache 2.0 - -cc_binary( - name = "option_writer_generator", - srcs = ["option_writer_generator.cc"], - deps = [ - "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", - "@flatbuffers", - ], -) - -cc_library( - name = "writer_lib", - srcs = [ - "enum_mapping.h", - "writer_lib.cc", - ], - hdrs = [ - "writer_lib.h", - ], - textual_hdrs = ["option_writer_generated.h"], - deps = [ - ":option_writer_gen", - "//tensorflow/contrib/lite:builtin_op_data", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:schema_fbs_version", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection", - ], -) - -cc_binary( - name = "writer", - srcs = ["writer.cc"], - deps = [ - ":writer_lib", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", - ], -) - -cc_test( - name = "writer_lib_test", - size = "small", - srcs = ["writer_lib_test.cc"], - deps = [ - ":writer_lib", - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite/kernels:builtin_ops", - "//tensorflow/contrib/lite/testing:util", - "//testing/base/public:gunit", - ], -) - -genrule( - name = "option_writer_gen", - outs = ["option_writer_generated.h"], - cmd = "$(location :option_writer_generator) $(@)", - tools = [":option_writer_generator"], -) diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h deleted file mode 100644 index 8bc464fd71..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ - -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" - -// TODO(aselle): Ideally extract this from the schema. - -namespace tflite { - -inline ActivationFunctionType TfLiteActivationToSchemaActivation( - TfLiteFusedActivation act) { - switch (act) { - case kTfLiteActNone: - return ActivationFunctionType_NONE; - case kTfLiteActRelu: - return ActivationFunctionType_RELU; - case kTfLiteActRelu1: - return ActivationFunctionType_RELU_N1_TO_1; - case kTfLiteActRelu6: - return ActivationFunctionType_RELU6; - case kTfLiteActTanh: - return ActivationFunctionType_TANH; - case kTfLiteActSignBit: - return ActivationFunctionType_SIGN_BIT; - case kTfLiteActSigmoid: - return ActivationFunctionType_NONE; // TODO(aselle): Add to schema - } - return ActivationFunctionType_NONE; -} - -inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) { - switch (padding) { - case kTfLitePaddingUnknown: - return Padding_SAME; // TODO(aselle): Consider an error. - case kTfLitePaddingSame: - return Padding_SAME; - case kTfLitePaddingValid: - return Padding_VALID; - } - return Padding_SAME; // TODO(aselle): Consider an error. -} - -inline TensorType TfLiteTypeToSchemaType(TfLiteType type) { - switch (type) { - // case kTfLiteNoType: return TensorType_NONE; - case kTfLiteNoType: - return TensorType_FLOAT32; // TODO(aselle): Consider an error. - case kTfLiteFloat32: - return TensorType_FLOAT32; - case kTfLiteInt32: - return TensorType_INT32; - case kTfLiteUInt8: - return TensorType_UINT8; - case kTfLiteInt64: - return TensorType_INT64; - case kTfLiteString: - return TensorType_STRING; - case kTfLiteBool: - return TensorType_BOOL; - case kTfLiteInt16: - return TensorType_INT16; - case kTfLiteComplex64: - return TensorType_COMPLEX64; - } - // TODO(aselle): consider an error -} - -inline FullyConnectedOptionsWeightsFormat -FullyConnectedOptionsWeightsFormatToSchema( - TfLiteFullyConnectedWeightsFormat format) { - switch (format) { - case kTfLiteFullyConnectedWeightsFormatDefault: - return FullyConnectedOptionsWeightsFormat_DEFAULT; - case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8: - return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8; - } -} - -inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) { - switch (type) { - case kTfLiteLSTMFullKernel: - return LSTMKernelType_FULL; - case kTfLiteLSTMBasicKernel: - return LSTMKernelType_BASIC; - } -} - -inline LSHProjectionType LSHProjectionTypeToSchema( - TfLiteLSHProjectionType type) { - switch (type) { - case kTfLiteLshProjectionUnknown: - return LSHProjectionType_UNKNOWN; - case kTfLiteLshProjectionSparse: - return LSHProjectionType_SPARSE; - case kTfLiteLshProjectionDense: - return LSHProjectionType_DENSE; - } -} - -} // namespace tflite -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc deleted file mode 100644 index e6d5a776b3..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc +++ /dev/null @@ -1,370 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include -#include -#include -#include -#include "flatbuffers/minireflect.h" // flatbuffers -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" - -namespace tflite { -namespace { -// This is generated by grepping -// cat third_party/tensorflow/contrib/lite/builtin_op_data.h -//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}" -static const char* param_structs[] = {"TfLiteConvParams", - "TfLitePoolParams", - "TfLiteDepthwiseConvParams", - "TfLiteSVDFParams", - "TfLiteRNNParams", - "TfLiteSequenceRNNParams", - "TfLiteFullyConnectedParams", - "TfLiteLSHProjectionParams", - "TfLiteSoftmaxParams", - "TfLiteConcatenationParams", - "TfLiteAddParams", - "TfLiteSpaceToBatchNDParams", - "TfLiteBatchToSpaceNDParams", - "TfLiteMulParams", - "TfLiteSubParams", - "TfLiteDivParams", - "TfLiteL2NormParams", - "TfLiteLocalResponseNormParams", - "TfLiteLSTMParams", - "TfLiteResizeBilinearParams", - "TfLitePadParams", - "TfLitePadV2Params", - "TfLiteReshapeParams", - "TfLiteSkipGramParams", - "TfLiteSpaceToDepthParams", - "TfLiteCastParams", - "TfLiteEmbeddingLookupSparseParams", - "TfLiteGatherParams", - "TfLiteTransposeParams", - "TfLiteReducerParams", - "TfLiteSplitParams", - "TfLiteSqueezeParams", - "TfLiteStridedSliceParams", - "TfLiteArgMaxParams", - "TfLiteArgMinParams", - "TfLiteTransposeConvParams", - "TfLiteSparseToDenseParams", - "TfLiteShapeParams", - "TfLiteFakeQuantParams", - "TfLitePackParams", - "TfLiteOneHotParams", - nullptr}; -} // namespace - -// Get rid of all underscores and make everything lower case to make name -// matching work for stuff like 3D vs 3d or RNN vs Rnn. -std::string ToCollapsed(const std::string& in) { - const char* s = in.c_str(); - bool first = true; - std::string out; - while (*s != '\0') { - if (*s == '_') { - first = true; - } else if (first) { - out.push_back(tolower(*s)); - first = false; - } else { - out.push_back(tolower(*s)); - } - s++; - } - return out; -} - -// A collection of information about builtin ops. -class OpOptionData { - public: - OpOptionData() { - BuildOpList(); - BuildOptionToTypeFunctionMap(); - BuildOpToOptionMap(); - } - - // A list of builtin operations - const std::vector& ops() const { return ops_; } - // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions') - const std::unordered_map& op_to_option() { - return op_to_option_; - } - // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions' - const std::unordered_map& option_to_struct() { - return option_to_struct_; - } - // Maps from option to a flatbuffer type function that describes that option. - const std::unordered_map& - option_to_type_function() { - return option_to_type_function_; - } - - private: - void BuildOpList() { - for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr; - ++curr) { - if (strlen(*curr) != 0) ops_.push_back(*curr); - } - } - - void BuildOptionToTypeFunctionMap() { - auto d = tflite::BuiltinOptionsTypeTable(); - for (int i = 0; i < d->num_elems; i++) { - flatbuffers::TypeCode code = d->type_codes[i]; - if (code.sequence_ref != -1) { - option_to_type_function_.insert( - std::make_pair(d->names[i], d->type_refs[code.sequence_ref])); - } - } - } - - void BuildOpToOptionMap() { - // Manually specified mappings between ops and options - op_to_option_["REDUCE_MAX"] = "ReducerOptions"; - op_to_option_["REDUCE_MIN"] = "ReducerOptions"; - op_to_option_["REDUCE_ANY"] = "ReducerOptions"; - op_to_option_["UNPACK"] = ""; - op_to_option_["SUM"] = "ReducerOptions"; - op_to_option_["REDUCE_MAX"] = "ReducerOptions"; - op_to_option_["REDUCE_PROD"] = "ReducerOptions"; - op_to_option_["MEAN"] = "ReducerOptions"; - op_to_option_["L2_POOL_2D"] = "Pool2DOptions"; - op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions"; - op_to_option_["MAX_POOL_2D"] = "Pool2DOptions"; - op_to_option_["L2_NORMALIZATION"] = "L2NormOptions"; - op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions"; - op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions"; - // Manually specified mappings between ops and options (none) - op_to_option_["EMBEDDING_LOOKUP"] = - ""; // TODO(aselle): maybe something else. - op_to_option_["FLOOR"] = ""; - op_to_option_["HASHTABLE_LOOKUP"] = - ""; // TODO(aselle): maybe something else. - op_to_option_["LOGISTIC"] = ""; - op_to_option_["RELU"] = ""; - op_to_option_["RELU_N1_TO_1"] = ""; - op_to_option_["RELU6"] = ""; - op_to_option_["TANH"] = ""; - op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else. - op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else. - op_to_option_["PRELU"] = ""; - op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions - op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions - op_to_option_["SIN"] = ""; - op_to_option_["LOG"] = ""; - op_to_option_["SQRT"] = ""; - op_to_option_["RSQRT"] = ""; - - // TODO(aselle): These are undesirable hacks. Consider changing C structs - option_to_struct_["Pool2DOptions"] = "TfLitePoolParams"; - option_to_struct_["Conv2DOptions"] = "TfLiteConvParams"; - option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams"; - option_to_struct_["LocalResponseNormalizationOptions"] = - "TfLiteLocalResponseNormParams"; - // Now for every op, try to find an option. - bool fatal = false; - for (auto op_name : ops_) { - bool found_option = false; - auto d = tflite::BuiltinOptionsTypeTable(); - std::string collapsed_option_name_guess = - ToCollapsed(op_name) + "options"; - // O(n^2) but not that big of n. - for (int i = 0; i < d->num_elems; i++) { - std::string option_name = d->names[i]; - std::string collapsed_option_name = ToCollapsed(option_name); - if (collapsed_option_name_guess == collapsed_option_name) { - op_to_option_.insert(std::make_pair(op_name, option_name)); - found_option = true; - break; - } - } - auto it = op_to_option_.find(op_name); - if (it == op_to_option_.end()) { - std::cerr << "Didn't find option for " << op_name << std::endl; - fatal = true; - } else if (!it->second.empty()) { - std::string option_name = it->second; - - if (option_to_struct_.find(option_name) == option_to_struct_.end()) { - bool param_struct_found = false; - std::string params_guess = std::string("TfLite") + option_name; - size_t start = params_guess.find("Options"); - size_t len = strlen("Options"); - params_guess.replace(start, len, "Params"); - for (auto* param = param_structs; *param != nullptr; param++) { - if (*param == params_guess) { - param_struct_found = true; - break; - } - } - if (!param_struct_found) { - std::cerr << "Failed to get param struct for option " << option_name - << std::endl; - fatal = true; - } else { - option_to_struct_.insert(std::make_pair(option_name, params_guess)); - } - } - } - } - } - - private: - std::vector ops_; - std::unordered_map op_to_option_; - std::unordered_map option_to_struct_; - std::unordered_map - option_to_type_function_; -}; - -void GenerateImportForOp(FILE* fp, const std::string& op_name, - const std::string& option_name, - const std::string& option_type, - const flatbuffers::TypeTable* options, - const std::string& struct_name) { - // Skip tricky ones for now - if (struct_name == "TfLiteResizeBilinearParams") return; - if (struct_name == "TfLiteSqueezeParams") return; - if (struct_name == "TfLiteEmbeddingLookupSparseParams") return; - if (struct_name == "TfLiteReshapeParams") return; - - fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str()); - fprintf(fp, - " const auto* params = reinterpret_cast(builtin_op_data);\n", - struct_name.c_str()); - - for (size_t i = 0; i < options->num_elems; i++) { - std::string elem_name = options->names[i]; - // TODO(aselle): Irregular naming in builtins - if (elem_name == "fused_activation_function") - elem_name = "activation"; - else if (elem_name == "stride_w") - elem_name = "stride_width"; - else if (elem_name == "stride_h") - elem_name = "stride_height"; - else if (elem_name == "dilation_h_factor") - elem_name = "dilation_height_factor"; - else if (elem_name == "dilation_w_factor") - elem_name = "dilation_width_factor"; - else if (elem_name == "new_shape") - elem_name = "shape"; - - flatbuffers::TypeCode code = options->type_codes[i]; - auto contained_type = code.sequence_ref != -1 - ? options->type_refs[code.sequence_ref] - : nullptr; - std::string mapper = ""; - if (contained_type == TensorTypeTypeTable) { - mapper = "TfLiteTypeToSchemaType"; - } else if (contained_type == ActivationFunctionTypeTypeTable) { - mapper = "TfLiteActivationToSchemaActivation"; - } else if (contained_type == PaddingTypeTable) { - mapper = "TfLitePaddingToSchemaPadding"; - } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) { - mapper = "FullyConnectedOptionsWeightsFormatToSchema"; - } else if (contained_type == LSTMKernelTypeTypeTable) { - mapper = "LSTMKernelTypeToSchema"; - } else if (contained_type == LSHProjectionTypeTypeTable) { - mapper = "LSHProjectionTypeToSchema"; - } - - fprintf(fp, - " auto val%zu = " - "%s(params->%s);\n", - i, mapper.c_str(), elem_name.c_str()); - } - fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str()); - for (size_t i = 0; i < options->num_elems; i++) { - fprintf(fp, ", val%zu", i); - } - fprintf(fp, ").Union();\n"); - fprintf(fp, " return std::make_pair(%s, union_type);\n", - option_type.c_str()); - fprintf(fp, " }\n break;\n"); -} - -void GenerateImport(OpOptionData* option, FILE* fp) { - std::unordered_set ignores; - ignores.insert("CONCAT_EMBEDDINGS"); - ignores.insert("CALL"); - - // Allow any op that doesn't have an options struct to be blocked - // together - for (const auto& op_name : option->ops()) { - auto option_it = option->op_to_option().find(op_name); - if (!option_it->second.empty() && ignores.find(op_name) == ignores.end()) - continue; - fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); - } - fprintf(fp, - " return std::make_pair(BuiltinOptions_NONE, " - "flatbuffers::Offset());\n break;\n"); - - // Iterate over each ops - for (const auto& op_name : option->ops()) { - if (ignores.find(op_name) != ignores.end()) continue; - // Get to the option and struct names, continuing if not found. - auto option_it = option->op_to_option().find(op_name); - if (option_it->second.empty()) continue; - std::string option_name = option_it->second; - std::string option_type = "BuiltinOptions_" + option_name; - auto option_func_it = option->option_to_type_function().find(option_name); - if (option_func_it == option->option_to_type_function().end()) continue; - auto struct_name_it = option->option_to_struct().find(option_name); - if (struct_name_it == option->option_to_struct().end()) { - // If no C struct, then it better have no arguments. - auto type_info = option_func_it->second(); - if (type_info->num_elems != 0) { - // We have non-zero arguments in the schema, this means there - // should be a struct. - fprintf(stderr, - "Op %s uses option struct %s which has no builtin struct\n", - op_name.c_str(), option_name.c_str()); - exit(1); - } - fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str()); - fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());", - option_type.c_str(), option_name.c_str()); - } else { - // If C struct, then we need to assign all properties - auto struct_name = struct_name_it->second; - GenerateImportForOp(fp, op_name, option_name, option_type, - option_func_it->second(), struct_name); - } - } - // TODO(aselle): Handle unhandled cases more gracefully. - fprintf(fp, - "default: return std::make_pair(BuiltinOptions_NONE, " - "flatbuffers::Offset());\n break;\n"); -} - -} // namespace tflite - -int main(int argc, char* argv[]) { - tflite::OpOptionData option; - if (argc != 2) { - fprintf(stderr, "Usage: %s \n", argv[0]); - return 1; - } - FILE* fp = fopen(argv[1], "w"); - tflite::GenerateImport(&option, fp); - fclose(fp); -} diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc deleted file mode 100644 index 20ede214fb..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/writer.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Just does a read/write loop of tflite file format using the interpreter as -// an intermediate. -// -// Usage: -// writer - -#include - -#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" - -int main(int argc, char* argv[]) { - if (argc != 3) { - fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]); - return 1; - } - std::unique_ptr model = - tflite::FlatBufferModel::BuildFromFile(argv[1]); - std::unique_ptr interpreter; - tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver; - tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter); - tflite::InterpreterWriter writer(interpreter.get()); - writer.Write(argv[2]); - - return 0; -} diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc deleted file mode 100644 index 52b17faf82..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc +++ /dev/null @@ -1,281 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" -#include -#include -#include -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" -#include "tensorflow/contrib/lite/version.h" - -namespace tflite { -template -using Offset = flatbuffers::Offset; -template -using Vector = flatbuffers::Vector; -using FlatBufferBuilder = flatbuffers::FlatBufferBuilder; - -std::pair> CreateBuiltinUnion( - FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) { - switch (op) { -#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h" - } - return std::make_pair(BuiltinOptions_NONE, Offset()); -} - -template -Offset> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb, - const T_INPUT& v) { - std::vector inputs(v.begin(), v.end()); - return fbb->template CreateVector(inputs); -} - -Offset>> InterpreterWriter::ExportOperators( - FlatBufferBuilder* fbb) { - std::vector> operators; - - std::vector operator_to_opcode; - // TODO(aselle): Augment this once we put execution plan in schema. - operator_to_opcode.resize(interpreter_->nodes_size(), -1); - for (int op_index : interpreter_->execution_plan()) { - const auto* node_and_registration = - interpreter_->node_and_registration(op_index); - const TfLiteRegistration* registration = &node_and_registration->second; - if (!registration->custom_name) { - operator_to_opcode[op_index] = - GetOpCodeForBuiltin(registration->builtin_code); - } else { - operator_to_opcode[op_index] = - GetOpCodeForCustom(registration->custom_name); - } - } - // second pass serialize operators - for (int op_index : interpreter_->execution_plan()) { - const auto* node_and_registration = - interpreter_->node_and_registration(op_index); - const TfLiteNode& node = node_and_registration->first; - const TfLiteRegistration& registration = node_and_registration->second; - Offset builtin_options; - BuiltinOptions builtin_options_type = BuiltinOptions_NONE; - // Custom data - // TODO(aselle): Custom options format is not known by default. Just assume - // for now. - auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS; - Offset> custom_options = 0; - - if (!registration.custom_name) { - // builtin - auto builtin_options_and_type = CreateBuiltinUnion( - fbb, static_cast(registration.builtin_code), - node.builtin_data); - builtin_options = builtin_options_and_type.second; - builtin_options_type = builtin_options_and_type.first; - } else { - auto custom_writer = custom_op_to_writer_.find(registration.custom_name); - if (custom_writer != custom_op_to_writer_.end() && - custom_writer->second) { - // delegate to custom writer if it exists - custom_writer->second(fbb, interpreter_, op_index, &custom_options, - &custom_options_format); - } else { - // use the custom data as fact - custom_options = fbb->CreateVector( - reinterpret_cast(node.custom_initial_data), - node.custom_initial_data_size); - } - } - - int opcode_index = operator_to_opcode[op_index]; - std::vector written_inputs = - RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs)); - std::vector written_outputs = - RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs)); - auto inputs = ExportVector(fbb, written_inputs); - auto outputs = ExportVector(fbb, written_outputs); - operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs, - builtin_options_type, builtin_options, - custom_options, custom_options_format)); - } - - return fbb->template CreateVector>(operators); -} - -Offset>> InterpreterWriter::ExportTensors( - FlatBufferBuilder* fbb) { - tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1); - - std::vector> tensors; - - // Make a map from tensor index to whether the tensor is a temporary. - std::vector tensor_is_temporary(interpreter_->tensors_size(), false); - for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) { - const auto* node_and_registration = - interpreter_->node_and_registration(op_index); - for (auto tensor_index : - TfLiteIntArrayView(node_and_registration->first.temporaries)) - tensor_is_temporary[tensor_index] = true; - } - - // Now we need to remap all used tensor indices - int curr_output_index = 0; - for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); - tensor_index++) { - if (!tensor_is_temporary[tensor_index]) { - tensor_to_written_tensor_[tensor_index] = curr_output_index++; - } - } - - for (int tensor_index = 0; tensor_index < interpreter_->tensors_size(); - ++tensor_index) { - // Skip temporaries. - if (tensor_is_temporary[tensor_index]) continue; - - if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) { - // We only need to convert non temporaries - if (tensor->allocation_type != kTfLiteArenaRw && - tensor->allocation_type != kTfLiteMmapRo && - tensor->allocation_type != kTfLiteArenaRwPersistent) - continue; - // Allocate a buffer index - int buffer_index = 0; // This is null - if (tensor->allocation_type == kTfLiteMmapRo) { - buffer_index = buffers_.size(); - buffers_.push_back(std::make_pair( - reinterpret_cast(tensor->data.raw), tensor->bytes)); - } - // Primitive type. - TensorType type = TfLiteTypeToSchemaType(tensor->type); - // Handle quantization - const Offset> null_array; - Offset> scale_array; - Offset> zero_point_array; - if (tensor->params.scale != 0.f) { - // We have quantization, make a single arugment array (multi channel - // quant needs updating here). - scale_array = fbb->CreateVector({tensor->params.scale}); - zero_point_array = - fbb->CreateVector({tensor->params.zero_point}); - } - Offset quantization_params = - CreateQuantizationParameters(*fbb, null_array, null_array, - scale_array, zero_point_array); - // Shape - TfLiteIntArrayView shape_view(tensor->dims); - std::vector shape = - std::vector(shape_view.begin(), shape_view.end()); - - tensors.push_back(CreateTensor(*fbb, ExportVector(fbb, shape), - type, buffer_index, - fbb->CreateString(tensor->name), - quantization_params, tensor->is_variable)); - } - } - return fbb->template CreateVector>(tensors); -} - -Offset>> InterpreterWriter::ExportBuffers( - FlatBufferBuilder* fbb) { - std::vector> buffer_vector; - for (auto buffer : buffers_) { - auto data_offset = fbb->CreateVector(buffer.first, buffer.second); - buffer_vector.push_back(CreateBuffer(*fbb, data_offset)); - } - return fbb->template CreateVector>(buffer_vector); -} - -Offset>> InterpreterWriter::CreateOpCodeTable( - FlatBufferBuilder* fbb) { - std::vector> codes; - for (auto it : opcodes_) { - const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str(); - codes.push_back(CreateOperatorCodeDirect( - *fbb, static_cast(it.builtin), custom_name)); - } - return fbb->template CreateVector>(codes); -} - -template -std::vector InterpreterWriter::RemapTensorIndicesToWritten( - const T& input) { - std::vector output; - output.reserve(input.size()); - for (int x : input) { - output.push_back(tensor_to_written_tensor_[x]); - } - return output; -} - -TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr* out, - size_t* size) { - if (!out || !size) return kTfLiteError; - FlatBufferBuilder builder(/*initial_size=*/10240); - - std::vector> subgraphs_as_vector; - { // subgraph specific stuff - auto tensors = ExportTensors(&builder); - std::vector written_inputs = - RemapTensorIndicesToWritten(interpreter_->inputs()); - std::vector written_outputs = - RemapTensorIndicesToWritten(interpreter_->outputs()); - auto inputs = ExportVector(&builder, written_inputs); - auto outputs = ExportVector(&builder, written_outputs); - - auto ops = ExportOperators(&builder); - subgraphs_as_vector.push_back( - CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0)); - } - Offset>> buffers = ExportBuffers(&builder); - - auto description = builder.CreateString("Exported from Interpreter."); - - auto op_codes = CreateOpCodeTable(&builder); - auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes, - builder.CreateVector(subgraphs_as_vector), - description, buffers); - ::tflite::FinishModelBuffer(builder, model); - const uint8_t* buffer = builder.GetBufferPointer(); - *size = builder.GetSize(); - (*out).reset(new uint8_t[*size]); - memcpy(out->get(), buffer, *size); - return kTfLiteOk; -} - -TfLiteStatus InterpreterWriter::Write(const std::string& filename) { - std::unique_ptr buffer; - size_t size; - TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size)); - - FILE* fp = fopen(filename.c_str(), "wb"); - if (!fp) return kTfLiteError; - - if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError; - if (fclose(fp)) return kTfLiteError; - - return kTfLiteOk; -} - -TfLiteStatus InterpreterWriter::RegisterCustomWriter( - const std::string& custom_name, CustomWriter custom_writer) { - if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) { - return kTfLiteError; - } - custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer)); - return kTfLiteOk; -} - -} // namespace tflite diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h deleted file mode 100644 index a98108b496..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib.h +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter. -// -// Usage: -// From command line: -// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer -// -- foo.tflite foo.out.tflite -// -// From C++ -// std::unique_ptr interpreter; -// // Build Interpreter however -// // ... -// InterpreterWriter(interpreter.get()).Write("output.tflite"); -#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ -#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ -#include -#include -#include "tensorflow/contrib/lite/builtin_op_data.h" -#include "tensorflow/contrib/lite/context_util.h" -#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h" -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h" -#include "tensorflow/contrib/lite/version.h" - -namespace tflite { - -// Handles writing TensorFlow Lite running interpreter to a serialized TF lite -// file format. -class InterpreterWriter { - public: - typedef flatbuffers::Offset (*CustomWriter)( - flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter, - int node_index, - flatbuffers::Offset>* output_options, - CustomOptionsFormat* custom_options_format); - - // Construct an interpreter writer for the specified `interpreter`. Then, - // a uses .Write() or .GetBuffer(...) to extract the data. - explicit InterpreterWriter(Interpreter* interpreter) - : interpreter_(interpreter) { - buffers_.push_back(std::make_pair(nullptr, 0)); - } - - // Get a buffer and size of a serialized flatbuffer. - TfLiteStatus GetBuffer(std::unique_ptr* out, size_t* size); - // Write the serialized flatbuffer to the prescribed `filename`. - TfLiteStatus Write(const std::string& filename); - // Registers a custom writer for a custom op. The customization allows the - // caller to change the custom data. - TfLiteStatus RegisterCustomWriter(const std::string& custom_name, - CustomWriter custom_writer); - - private: - template - using Offset = flatbuffers::Offset; - template - Offset> ExportVector( - flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v); - Offset>> ExportTensors( - flatbuffers::FlatBufferBuilder* fbb); - Offset>> ExportOperators( - flatbuffers::FlatBufferBuilder* fbb); - Offset>> CreateOpCodeTable( - flatbuffers::FlatBufferBuilder* fbb); - Offset>> ExportBuffers( - flatbuffers::FlatBufferBuilder* fbb); - - template - std::vector RemapTensorIndicesToWritten(const T& input); - - int GetOpCodeForBuiltin(int builtin_op_index) { - // auto it = builtin_op_to_opcode_.find(builtin_op_index); - std::pair result = - builtin_op_to_opcode_.insert( - std::make_pair(builtin_op_index, opcodes_.size())); - if (result.second) { - opcodes_.push_back({builtin_op_index, ""}); - } - return result.first->second; - } - - int GetOpCodeForCustom(const std::string& custom_name) { - std::pair result = - custom_op_to_opcode_.insert( - std::make_pair(custom_name, opcodes_.size())); - if (result.second) { - opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name}); - } - return result.first->second; - } - - // The interpreter we are writing - Interpreter* interpreter_; - // Keep track of byte buffers - std::vector> buffers_; - // List of op codes and mappings from builtin or custom op to opcode - struct OpCode { - int builtin; - std::string custom; - }; - // For every tensor index in the interpreter, the index in the written. - // This is different due to temporary tensors not being written. - std::vector tensor_to_written_tensor_; - // List of used opcodes - std::vector opcodes_; - std::unordered_map builtin_op_to_opcode_; - std::unordered_map custom_op_to_opcode_; - std::unordered_map custom_op_to_writer_; -}; - -} // namespace tflite - -#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_ diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc deleted file mode 100644 index 49194a76c8..0000000000 --- a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h" -#include -#include "tensorflow/contrib/lite/interpreter.h" -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/testing/util.h" - -namespace tflite { -// Make an interpreter that has no tensors and no nodes -// TODO(b/113731921): add more tests. -TEST(Writer, BasicTest) { - Interpreter interpreter; - interpreter.AddTensors(3); - float foo[] = {1, 2, 3}; - interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3}, - TfLiteQuantizationParams()); - interpreter.SetTensorParametersReadOnly( - 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(), - reinterpret_cast(foo), sizeof(foo)); - interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3}, - TfLiteQuantizationParams()); - interpreter.SetInputs({0, 1}); - interpreter.SetOutputs({2}); - const char* initial_data = ""; - tflite::ops::builtin::BuiltinOpResolver resolver; - TfLiteAddParams* builtin_data = - reinterpret_cast(malloc(sizeof(TfLiteAddParams))); - builtin_data->activation = kTfLiteActNone; - const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1); - interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0, - reinterpret_cast(builtin_data), reg); - - InterpreterWriter writer(&interpreter); - writer.Write("/tmp/test.tflite"); - std::unique_ptr model = - FlatBufferModel::BuildFromFile("/tmp/test.tflite"); - InterpreterBuilder builder(*model, resolver); - std::unique_ptr new_interpreter; - builder(&new_interpreter); -} - -} // namespace tflite - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/op_resolver.cc index a9885f7737..f6e435e982 100644 --- a/tensorflow/contrib/lite/op_resolver.cc +++ b/tensorflow/contrib/lite/op_resolver.cc @@ -46,8 +46,6 @@ void MutableOpResolver::AddCustom(const char* name, TfLiteRegistration* registration, int min_version, int max_version) { for (int version = min_version; version <= max_version; ++version) { - // TODO(aselle): This should verify that the incoming registration - // has the name in the registration already and it matches!!! TfLiteRegistration new_registration = *registration; new_registration.builtin_code = BuiltinOperator_CUSTOM; new_registration.version = version; diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD index 55bf2c48b9..28a7e50003 100644 --- a/tensorflow/contrib/lite/schema/BUILD +++ b/tensorflow/contrib/lite/schema/BUILD @@ -56,20 +56,6 @@ flatbuffer_cc_library( srcs = ["schema.fbs"], ) -# Generic schema for inference on device (but with reflections makes bigger). -flatbuffer_cc_library( - name = "schema_fbs_with_reflection", - srcs = ["schema.fbs"], - flatc_args = [ - "--reflect-types", - "--reflect-names", - "--no-union-value-namespacing", - "--gen-object-api", - ], - gen_reflections = True, - out_prefix = "reflection/", -) - # Schema test to make sure we don't introduce backward incompatible changes # to schemas. cc_test( -- GitLab From ec6ea3ad0ac405c2516036d0ccf60149fad9c4c4 Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 4 Sep 2018 16:05:05 -0700 Subject: [PATCH 083/540] contrib/distributions: Test code cleanups - Remove unnecessary test_session() boilerplate when executing eagerly - Use self.cached_session() instead of self.test_session() when using graphs self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 211542360 --- .../distributions/bernoulli_test.py | 196 +++--- .../kernel_tests/distributions/beta_test.py | 462 +++++++------ .../distributions/bijector_test.py | 13 +- .../distributions/dirichlet_test.py | 262 ++++---- .../distributions/exponential_test.py | 187 +++--- .../kernel_tests/distributions/gamma_test.py | 529 ++++++++------- .../distributions/laplace_test.py | 439 ++++++------- .../kernel_tests/distributions/normal_test.py | 607 +++++++++--------- .../distributions/special_math_test.py | 35 +- .../distributions/student_t_test.py | 505 +++++++-------- .../distributions/uniform_test.py | 354 +++++----- .../kernel_tests/distributions/util_test.py | 230 +++---- 12 files changed, 1803 insertions(+), 2016 deletions(-) diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py index 9ad77a54cb..26d013bccb 100644 --- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py +++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py @@ -62,59 +62,50 @@ class BernoulliTest(test.TestCase): def testP(self): p = [0.2, 0.4] dist = bernoulli.Bernoulli(probs=p) - with self.test_session(): - self.assertAllClose(p, self.evaluate(dist.probs)) + self.assertAllClose(p, self.evaluate(dist.probs)) @test_util.run_in_graph_and_eager_modes def testLogits(self): logits = [-42., 42.] dist = bernoulli.Bernoulli(logits=logits) - with self.test_session(): - self.assertAllClose(logits, self.evaluate(dist.logits)) + self.assertAllClose(logits, self.evaluate(dist.logits)) if not special: return - with self.test_session(): - self.assertAllClose(special.expit(logits), self.evaluate(dist.probs)) + self.assertAllClose(special.expit(logits), self.evaluate(dist.probs)) p = [0.01, 0.99, 0.42] dist = bernoulli.Bernoulli(probs=p) - with self.test_session(): - self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) + self.assertAllClose(special.logit(p), self.evaluate(dist.logits)) @test_util.run_in_graph_and_eager_modes def testInvalidP(self): invalid_ps = [1.01, 2.] for p in invalid_ps: - with self.test_session(): - with self.assertRaisesOpError("probs has components greater than 1"): - dist = bernoulli.Bernoulli(probs=p, validate_args=True) - self.evaluate(dist.probs) + with self.assertRaisesOpError("probs has components greater than 1"): + dist = bernoulli.Bernoulli(probs=p, validate_args=True) + self.evaluate(dist.probs) invalid_ps = [-0.01, -3.] for p in invalid_ps: - with self.test_session(): - with self.assertRaisesOpError("Condition x >= 0"): - dist = bernoulli.Bernoulli(probs=p, validate_args=True) - self.evaluate(dist.probs) + with self.assertRaisesOpError("Condition x >= 0"): + dist = bernoulli.Bernoulli(probs=p, validate_args=True) + self.evaluate(dist.probs) valid_ps = [0.0, 0.5, 1.0] for p in valid_ps: - with self.test_session(): - dist = bernoulli.Bernoulli(probs=p) - self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail + dist = bernoulli.Bernoulli(probs=p) + self.assertEqual(p, self.evaluate(dist.probs)) # Should not fail @test_util.run_in_graph_and_eager_modes def testShapes(self): - with self.test_session(): - for batch_shape in ([], [1], [2, 3, 4]): - dist = make_bernoulli(batch_shape) - self.assertAllEqual(batch_shape, dist.batch_shape.as_list()) - self.assertAllEqual(batch_shape, - self.evaluate(dist.batch_shape_tensor())) - self.assertAllEqual([], dist.event_shape.as_list()) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + for batch_shape in ([], [1], [2, 3, 4]): + dist = make_bernoulli(batch_shape) + self.assertAllEqual(batch_shape, dist.batch_shape.as_list()) + self.assertAllEqual(batch_shape, self.evaluate(dist.batch_shape_tensor())) + self.assertAllEqual([], dist.event_shape.as_list()) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) @test_util.run_in_graph_and_eager_modes def testDtype(self): @@ -137,31 +128,29 @@ class BernoulliTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def _testPmf(self, **kwargs): dist = bernoulli.Bernoulli(**kwargs) - with self.test_session(): - # pylint: disable=bad-continuation - xs = [ - 0, - [1], - [1, 0], - [[1, 0]], - [[1, 0], [1, 1]], - ] - expected_pmfs = [ - [[0.8, 0.6], [0.7, 0.4]], - [[0.2, 0.4], [0.3, 0.6]], - [[0.2, 0.6], [0.3, 0.4]], - [[0.2, 0.6], [0.3, 0.4]], - [[0.2, 0.6], [0.3, 0.6]], - ] - # pylint: enable=bad-continuation - - for x, expected_pmf in zip(xs, expected_pmfs): - self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf) - self.assertAllClose( - self.evaluate(dist.log_prob(x)), np.log(expected_pmf)) + # pylint: disable=bad-continuation + xs = [ + 0, + [1], + [1, 0], + [[1, 0]], + [[1, 0], [1, 1]], + ] + expected_pmfs = [ + [[0.8, 0.6], [0.7, 0.4]], + [[0.2, 0.4], [0.3, 0.6]], + [[0.2, 0.6], [0.3, 0.4]], + [[0.2, 0.6], [0.3, 0.4]], + [[0.2, 0.6], [0.3, 0.6]], + ] + # pylint: enable=bad-continuation + + for x, expected_pmf in zip(xs, expected_pmfs): + self.assertAllClose(self.evaluate(dist.prob(x)), expected_pmf) + self.assertAllClose(self.evaluate(dist.log_prob(x)), np.log(expected_pmf)) def testPmfCorrectBroadcastDynamicShape(self): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtype=dtypes.float32) dist = bernoulli.Bernoulli(probs=p) event1 = [1, 0, 1] @@ -178,12 +167,11 @@ class BernoulliTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testPmfInvalid(self): p = [0.1, 0.2, 0.7] - with self.test_session(): - dist = bernoulli.Bernoulli(probs=p, validate_args=True) - with self.assertRaisesOpError("must be non-negative."): - self.evaluate(dist.prob([1, 1, -1])) - with self.assertRaisesOpError("Elements cannot exceed 1."): - self.evaluate(dist.prob([2, 0, 1])) + dist = bernoulli.Bernoulli(probs=p, validate_args=True) + with self.assertRaisesOpError("must be non-negative."): + self.evaluate(dist.prob([1, 1, -1])) + with self.assertRaisesOpError("Elements cannot exceed 1."): + self.evaluate(dist.prob([2, 0, 1])) @test_util.run_in_graph_and_eager_modes def testPmfWithP(self): @@ -194,7 +182,7 @@ class BernoulliTest(test.TestCase): self._testPmf(logits=special.logit(p)) def testBroadcasting(self): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtypes.float32) dist = bernoulli.Bernoulli(probs=p) self.assertAllClose(np.log(0.5), dist.log_prob(1).eval({p: 0.5})) @@ -208,70 +196,63 @@ class BernoulliTest(test.TestCase): })) def testPmfShapes(self): - with self.test_session(): + with self.cached_session(): p = array_ops.placeholder(dtypes.float32, shape=[None, 1]) dist = bernoulli.Bernoulli(probs=p) self.assertEqual(2, len(dist.log_prob(1).eval({p: [[0.5], [0.5]]}).shape)) - with self.test_session(): dist = bernoulli.Bernoulli(probs=0.5) self.assertEqual(2, len(self.evaluate(dist.log_prob([[1], [1]])).shape)) - with self.test_session(): dist = bernoulli.Bernoulli(probs=0.5) self.assertEqual((), dist.log_prob(1).get_shape()) self.assertEqual((1), dist.log_prob([1]).get_shape()) self.assertEqual((2, 1), dist.log_prob([[1], [1]]).get_shape()) - with self.test_session(): dist = bernoulli.Bernoulli(probs=[[0.5], [0.5]]) self.assertEqual((2, 1), dist.log_prob(1).get_shape()) @test_util.run_in_graph_and_eager_modes def testBoundaryConditions(self): - with self.test_session(): - dist = bernoulli.Bernoulli(probs=1.0) - self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) - self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) + dist = bernoulli.Bernoulli(probs=1.0) + self.assertAllClose(np.nan, self.evaluate(dist.log_prob(0))) + self.assertAllClose([np.nan], [self.evaluate(dist.log_prob(1))]) @test_util.run_in_graph_and_eager_modes def testEntropyNoBatch(self): p = 0.2 dist = bernoulli.Bernoulli(probs=p) - with self.test_session(): - self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) + self.assertAllClose(self.evaluate(dist.entropy()), entropy(p)) @test_util.run_in_graph_and_eager_modes def testEntropyWithBatch(self): p = [[0.1, 0.7], [0.2, 0.6]] dist = bernoulli.Bernoulli(probs=p, validate_args=False) - with self.test_session(): - self.assertAllClose( - self.evaluate(dist.entropy()), - [[entropy(0.1), entropy(0.7)], [entropy(0.2), - entropy(0.6)]]) + self.assertAllClose( + self.evaluate(dist.entropy()), + [[entropy(0.1), entropy(0.7)], [entropy(0.2), + entropy(0.6)]]) @test_util.run_in_graph_and_eager_modes def testSampleN(self): - with self.test_session(): - p = [0.2, 0.6] - dist = bernoulli.Bernoulli(probs=p) - n = 100000 - samples = dist.sample(n) - samples.set_shape([n, 2]) - self.assertEqual(samples.dtype, dtypes.int32) - sample_values = self.evaluate(samples) - self.assertTrue(np.all(sample_values >= 0)) - self.assertTrue(np.all(sample_values <= 1)) - # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) / - # n). This means that the tolerance is very sensitive to the value of p - # as well as n. - self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2) - self.assertEqual(set([0, 1]), set(sample_values.flatten())) - # In this test we're just interested in verifying there isn't a crash - # owing to mismatched types. b/30940152 - dist = bernoulli.Bernoulli(np.log([.2, .4])) - self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list()) + p = [0.2, 0.6] + dist = bernoulli.Bernoulli(probs=p) + n = 100000 + samples = dist.sample(n) + samples.set_shape([n, 2]) + self.assertEqual(samples.dtype, dtypes.int32) + sample_values = self.evaluate(samples) + self.assertTrue(np.all(sample_values >= 0)) + self.assertTrue(np.all(sample_values <= 1)) + # Note that the standard error for the sample mean is ~ sqrt(p * (1 - p) / + # n). This means that the tolerance is very sensitive to the value of p + # as well as n. + self.assertAllClose(p, np.mean(sample_values, axis=0), atol=1e-2) + self.assertEqual(set([0, 1]), set(sample_values.flatten())) + # In this test we're just interested in verifying there isn't a crash + # owing to mismatched types. b/30940152 + dist = bernoulli.Bernoulli(np.log([.2, .4])) + self.assertAllEqual((1, 2), dist.sample(1, seed=42).get_shape().as_list()) @test_util.run_in_graph_and_eager_modes def testNotReparameterized(self): @@ -284,7 +265,7 @@ class BernoulliTest(test.TestCase): self.assertIsNone(grad_p) def testSampleActsLikeSampleN(self): - with self.test_session() as sess: + with self.cached_session() as sess: p = [0.2, 0.6] dist = bernoulli.Bernoulli(probs=p) n = 1000 @@ -299,27 +280,24 @@ class BernoulliTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMean(self): - with self.test_session(): - p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) - dist = bernoulli.Bernoulli(probs=p) - self.assertAllEqual(self.evaluate(dist.mean()), p) + p = np.array([[0.2, 0.7], [0.5, 0.4]], dtype=np.float32) + dist = bernoulli.Bernoulli(probs=p) + self.assertAllEqual(self.evaluate(dist.mean()), p) @test_util.run_in_graph_and_eager_modes def testVarianceAndStd(self): var = lambda p: p * (1. - p) - with self.test_session(): - p = [[0.2, 0.7], [0.5, 0.4]] - dist = bernoulli.Bernoulli(probs=p) - self.assertAllClose( - self.evaluate(dist.variance()), - np.array( - [[var(0.2), var(0.7)], [var(0.5), var(0.4)]], dtype=np.float32)) - self.assertAllClose( - self.evaluate(dist.stddev()), - np.array( - [[np.sqrt(var(0.2)), np.sqrt(var(0.7))], - [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], - dtype=np.float32)) + p = [[0.2, 0.7], [0.5, 0.4]] + dist = bernoulli.Bernoulli(probs=p) + self.assertAllClose( + self.evaluate(dist.variance()), + np.array([[var(0.2), var(0.7)], [var(0.5), var(0.4)]], + dtype=np.float32)) + self.assertAllClose( + self.evaluate(dist.stddev()), + np.array([[np.sqrt(var(0.2)), np.sqrt(var(0.7))], + [np.sqrt(var(0.5)), np.sqrt(var(0.4))]], + dtype=np.float32)) @test_util.run_in_graph_and_eager_modes def testBernoulliBernoulliKL(self): diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py index 36f3ffc333..d580a415dd 100644 --- a/tensorflow/python/kernel_tests/distributions/beta_test.py +++ b/tensorflow/python/kernel_tests/distributions/beta_test.py @@ -20,7 +20,6 @@ import importlib import numpy as np -from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import random_seed @@ -51,237 +50,215 @@ stats = try_import("scipy.stats") class BetaTest(test.TestCase): def testSimpleShapes(self): - with self.test_session(): - a = np.random.rand(3) - b = np.random.rand(3) - dist = beta_lib.Beta(a, b) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) - self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) + a = np.random.rand(3) + b = np.random.rand(3) + dist = beta_lib.Beta(a, b) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) def testComplexShapes(self): - with self.test_session(): - a = np.random.rand(3, 2, 2) - b = np.random.rand(3, 2, 2) - dist = beta_lib.Beta(a, b) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) - self.assertEqual( - tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + a = np.random.rand(3, 2, 2) + b = np.random.rand(3, 2, 2) + dist = beta_lib.Beta(a, b) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) def testComplexShapesBroadcast(self): - with self.test_session(): - a = np.random.rand(3, 2, 2) - b = np.random.rand(2, 2) - dist = beta_lib.Beta(a, b) - self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) - self.assertEqual( - tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + a = np.random.rand(3, 2, 2) + b = np.random.rand(2, 2) + dist = beta_lib.Beta(a, b) + self.assertAllEqual([], self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2, 2], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) def testAlphaProperty(self): a = [[1., 2, 3]] b = [[2., 4, 3]] - with self.test_session(): - dist = beta_lib.Beta(a, b) - self.assertEqual([1, 3], dist.concentration1.get_shape()) - self.assertAllClose(a, self.evaluate(dist.concentration1)) + dist = beta_lib.Beta(a, b) + self.assertEqual([1, 3], dist.concentration1.get_shape()) + self.assertAllClose(a, self.evaluate(dist.concentration1)) def testBetaProperty(self): a = [[1., 2, 3]] b = [[2., 4, 3]] - with self.test_session(): - dist = beta_lib.Beta(a, b) - self.assertEqual([1, 3], dist.concentration0.get_shape()) - self.assertAllClose(b, self.evaluate(dist.concentration0)) + dist = beta_lib.Beta(a, b) + self.assertEqual([1, 3], dist.concentration0.get_shape()) + self.assertAllClose(b, self.evaluate(dist.concentration0)) def testPdfXProper(self): a = [[1., 2, 3]] b = [[2., 4, 3]] - with self.test_session(): - dist = beta_lib.Beta(a, b, validate_args=True) - self.evaluate(dist.prob([.1, .3, .6])) - self.evaluate(dist.prob([.2, .3, .5])) - # Either condition can trigger. - with self.assertRaisesOpError("sample must be positive"): - self.evaluate(dist.prob([-1., 0.1, 0.5])) - with self.assertRaisesOpError("sample must be positive"): - self.evaluate(dist.prob([0., 0.1, 0.5])) - with self.assertRaisesOpError("sample must be less than `1`"): - self.evaluate(dist.prob([.1, .2, 1.2])) - with self.assertRaisesOpError("sample must be less than `1`"): - self.evaluate(dist.prob([.1, .2, 1.0])) + dist = beta_lib.Beta(a, b, validate_args=True) + self.evaluate(dist.prob([.1, .3, .6])) + self.evaluate(dist.prob([.2, .3, .5])) + # Either condition can trigger. + with self.assertRaisesOpError("sample must be positive"): + self.evaluate(dist.prob([-1., 0.1, 0.5])) + with self.assertRaisesOpError("sample must be positive"): + self.evaluate(dist.prob([0., 0.1, 0.5])) + with self.assertRaisesOpError("sample must be less than `1`"): + self.evaluate(dist.prob([.1, .2, 1.2])) + with self.assertRaisesOpError("sample must be less than `1`"): + self.evaluate(dist.prob([.1, .2, 1.0])) def testPdfTwoBatches(self): - with self.test_session(): - a = [1., 2] - b = [1., 2] - x = [.5, .5] - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) - self.assertEqual((2,), pdf.get_shape()) + a = [1., 2] + b = [1., 2] + x = [.5, .5] + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) + self.assertEqual((2,), pdf.get_shape()) def testPdfTwoBatchesNontrivialX(self): - with self.test_session(): - a = [1., 2] - b = [1., 2] - x = [.3, .7] - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([1, 63. / 50], self.evaluate(pdf)) - self.assertEqual((2,), pdf.get_shape()) + a = [1., 2] + b = [1., 2] + x = [.3, .7] + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([1, 63. / 50], self.evaluate(pdf)) + self.assertEqual((2,), pdf.get_shape()) def testPdfUniformZeroBatch(self): - with self.test_session(): - # This is equivalent to a uniform distribution - a = 1. - b = 1. - x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([1.] * 5, self.evaluate(pdf)) - self.assertEqual((5,), pdf.get_shape()) + # This is equivalent to a uniform distribution + a = 1. + b = 1. + x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([1.] * 5, self.evaluate(pdf)) + self.assertEqual((5,), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - a = [[1., 2]] - b = [[1., 2]] - x = [[.5, .5], [.3, .7]] - dist = beta_lib.Beta(a, b) - pdf = dist.prob(x) - self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [[1., 2]] + b = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = beta_lib.Beta(a, b) + pdf = dist.prob(x) + self.assertAllClose([[1., 3. / 2], [1., 63. / 50]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - a = [1., 2] - b = [1., 2] - x = [[.5, .5], [.2, .8]] - pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [1., 2] + b = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = beta_lib.Beta(a, b).prob(x) + self.assertAllClose([[1., 3. / 2], [1., 24. / 25]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - a = [[1., 2], [2., 3]] - b = [[1., 2], [2., 3]] - x = [[.5, .5]] - pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = beta_lib.Beta(a, b).prob(x) + self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - a = [[1., 2], [2., 3]] - b = [[1., 2], [2., 3]] - x = [.5, .5] - pdf = beta_lib.Beta(a, b).prob(x) - self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) - self.assertEqual((2, 2), pdf.get_shape()) + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = beta_lib.Beta(a, b).prob(x) + self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf)) + self.assertEqual((2, 2), pdf.get_shape()) def testBetaMean(self): - with session.Session(): - a = [1., 2, 3] - b = [2., 4, 1.2] - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.mean().get_shape(), (3,)) - if not stats: - return - expected_mean = stats.beta.mean(a, b) - self.assertAllClose(expected_mean, self.evaluate(dist.mean())) + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.mean().get_shape(), (3,)) + if not stats: + return + expected_mean = stats.beta.mean(a, b) + self.assertAllClose(expected_mean, self.evaluate(dist.mean())) def testBetaVariance(self): - with session.Session(): - a = [1., 2, 3] - b = [2., 4, 1.2] - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.variance().get_shape(), (3,)) - if not stats: - return - expected_variance = stats.beta.var(a, b) - self.assertAllClose(expected_variance, self.evaluate(dist.variance())) + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.variance().get_shape(), (3,)) + if not stats: + return + expected_variance = stats.beta.var(a, b) + self.assertAllClose(expected_variance, self.evaluate(dist.variance())) def testBetaMode(self): - with session.Session(): - a = np.array([1.1, 2, 3]) - b = np.array([2., 4, 1.2]) - expected_mode = (a - 1) / (a + b - 2) - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.mode().get_shape(), (3,)) - self.assertAllClose(expected_mode, self.evaluate(dist.mode())) + a = np.array([1.1, 2, 3]) + b = np.array([2., 4, 1.2]) + expected_mode = (a - 1) / (a + b - 2) + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.mode().get_shape(), (3,)) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) def testBetaModeInvalid(self): - with session.Session(): - a = np.array([1., 2, 3]) - b = np.array([2., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): - self.evaluate(dist.mode()) - - a = np.array([2., 2, 3]) - b = np.array([1., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): - self.evaluate(dist.mode()) + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + self.evaluate(dist.mode()) + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + self.evaluate(dist.mode()) def testBetaModeEnableAllowNanStats(self): - with session.Session(): - a = np.array([1., 2, 3]) - b = np.array([2., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=True) + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=True) - expected_mode = (a - 1) / (a + b - 2) - expected_mode[0] = np.nan - self.assertEqual((3,), dist.mode().get_shape()) - self.assertAllClose(expected_mode, self.evaluate(dist.mode())) + expected_mode = (a - 1) / (a + b - 2) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) - a = np.array([2., 2, 3]) - b = np.array([1., 4, 1.2]) - dist = beta_lib.Beta(a, b, allow_nan_stats=True) + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = beta_lib.Beta(a, b, allow_nan_stats=True) - expected_mode = (a - 1) / (a + b - 2) - expected_mode[0] = np.nan - self.assertEqual((3,), dist.mode().get_shape()) - self.assertAllClose(expected_mode, self.evaluate(dist.mode())) + expected_mode = (a - 1) / (a + b - 2) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, self.evaluate(dist.mode())) def testBetaEntropy(self): - with session.Session(): - a = [1., 2, 3] - b = [2., 4, 1.2] - dist = beta_lib.Beta(a, b) - self.assertEqual(dist.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.beta.entropy(a, b) - self.assertAllClose(expected_entropy, self.evaluate(dist.entropy())) + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = beta_lib.Beta(a, b) + self.assertEqual(dist.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.beta.entropy(a, b) + self.assertAllClose(expected_entropy, self.evaluate(dist.entropy())) def testBetaSample(self): - with self.test_session(): - a = 1. - b = 2. - beta = beta_lib.Beta(a, b) - n = constant_op.constant(100000) - samples = beta.sample(n) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000,)) - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - self.assertLess( - stats.kstest( - # Beta is a univariate distribution. - sample_values, - stats.beta(a=1., b=2.).cdf)[0], - 0.01) - # The standard error of the sample mean is 1 / (sqrt(18 * n)) - self.assertAllClose( - sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2) - self.assertAllClose( - np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) + a = 1. + b = 2. + beta = beta_lib.Beta(a, b) + n = constant_op.constant(100000) + samples = beta.sample(n) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000,)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertLess( + stats.kstest( + # Beta is a univariate distribution. + sample_values, + stats.beta(a=1., b=2.).cdf)[0], + 0.01) + # The standard error of the sample mean is 1 / (sqrt(18 * n)) + self.assertAllClose( + sample_values.mean(axis=0), stats.beta.mean(a, b), atol=1e-2) + self.assertAllClose( + np.cov(sample_values, rowvar=0), stats.beta.var(a, b), atol=1e-1) def testBetaFullyReparameterized(self): a = constant_op.constant(1.0) @@ -297,78 +274,71 @@ class BetaTest(test.TestCase): # Test that sampling with the same seed twice gives the same results. def testBetaSampleMultipleTimes(self): - with self.test_session(): - a_val = 1. - b_val = 2. - n_val = 100 + a_val = 1. + b_val = 2. + n_val = 100 - random_seed.set_random_seed(654321) - beta1 = beta_lib.Beta(concentration1=a_val, - concentration0=b_val, - name="beta1") - samples1 = self.evaluate(beta1.sample(n_val, seed=123456)) + random_seed.set_random_seed(654321) + beta1 = beta_lib.Beta( + concentration1=a_val, concentration0=b_val, name="beta1") + samples1 = self.evaluate(beta1.sample(n_val, seed=123456)) - random_seed.set_random_seed(654321) - beta2 = beta_lib.Beta(concentration1=a_val, - concentration0=b_val, - name="beta2") - samples2 = self.evaluate(beta2.sample(n_val, seed=123456)) + random_seed.set_random_seed(654321) + beta2 = beta_lib.Beta( + concentration1=a_val, concentration0=b_val, name="beta2") + samples2 = self.evaluate(beta2.sample(n_val, seed=123456)) - self.assertAllClose(samples1, samples2) + self.assertAllClose(samples1, samples2) def testBetaSampleMultidimensional(self): - with self.test_session(): - a = np.random.rand(3, 2, 2).astype(np.float32) - b = np.random.rand(3, 2, 2).astype(np.float32) - beta = beta_lib.Beta(a, b) - n = constant_op.constant(100000) - samples = beta.sample(n) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - self.assertAllClose( - sample_values[:, 1, :].mean(axis=0), - stats.beta.mean(a, b)[1, :], - atol=1e-1) + a = np.random.rand(3, 2, 2).astype(np.float32) + b = np.random.rand(3, 2, 2).astype(np.float32) + beta = beta_lib.Beta(a, b) + n = constant_op.constant(100000) + samples = beta.sample(n) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertAllClose( + sample_values[:, 1, :].mean(axis=0), + stats.beta.mean(a, b)[1, :], + atol=1e-1) def testBetaCdf(self): - with self.test_session(): - shape = (30, 40, 50) - for dt in (np.float32, np.float64): - a = 10. * np.random.random(shape).astype(dt) - b = 10. * np.random.random(shape).astype(dt) - x = np.random.random(shape).astype(dt) - actual = self.evaluate(beta_lib.Beta(a, b).cdf(x)) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) - if not stats: - return - self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = self.evaluate(beta_lib.Beta(a, b).cdf(x)) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) def testBetaLogCdf(self): - with self.test_session(): - shape = (30, 40, 50) - for dt in (np.float32, np.float64): - a = 10. * np.random.random(shape).astype(dt) - b = 10. * np.random.random(shape).astype(dt) - x = np.random.random(shape).astype(dt) - actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x))) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) - self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) - if not stats: - return - self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = self.evaluate(math_ops.exp(beta_lib.Beta(a, b).log_cdf(x))) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose(stats.beta.cdf(x, a, b), actual, rtol=1e-4, atol=0) def testBetaWithSoftplusConcentration(self): - with self.test_session(): - a, b = -4.2, -9.1 - dist = beta_lib.BetaWithSoftplusConcentration(a, b) - self.assertAllClose( - self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1)) - self.assertAllClose( - self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0)) + a, b = -4.2, -9.1 + dist = beta_lib.BetaWithSoftplusConcentration(a, b) + self.assertAllClose( + self.evaluate(nn_ops.softplus(a)), self.evaluate(dist.concentration1)) + self.assertAllClose( + self.evaluate(nn_ops.softplus(b)), self.evaluate(dist.concentration0)) def testBetaBetaKL(self): for shape in [(10,), (4, 5)]: diff --git a/tensorflow/python/kernel_tests/distributions/bijector_test.py b/tensorflow/python/kernel_tests/distributions/bijector_test.py index 8b11556330..e20f59f48a 100644 --- a/tensorflow/python/kernel_tests/distributions/bijector_test.py +++ b/tensorflow/python/kernel_tests/distributions/bijector_test.py @@ -36,11 +36,10 @@ class BaseBijectorTest(test.TestCase): """Tests properties of the Bijector base-class.""" def testIsAbstract(self): - with self.test_session(): - with self.assertRaisesRegexp(TypeError, - ("Can't instantiate abstract class Bijector " - "with abstract methods __init__")): - bijector.Bijector() # pylint: disable=abstract-class-instantiated + with self.assertRaisesRegexp(TypeError, + ("Can't instantiate abstract class Bijector " + "with abstract methods __init__")): + bijector.Bijector() # pylint: disable=abstract-class-instantiated def testDefaults(self): class _BareBonesBijector(bijector.Bijector): @@ -136,7 +135,7 @@ class BijectorTestEventNdims(test.TestCase): def testBijectorDynamicEventNdims(self): bij = BrokenBijector(validate_args=True) event_ndims = array_ops.placeholder(dtype=np.int32, shape=None) - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Expected scalar"): bij.forward_log_det_jacobian(1., event_ndims=event_ndims).eval({ event_ndims: (1, 2)}) @@ -308,7 +307,7 @@ class BijectorReduceEventDimsTest(test.TestCase): event_ndims = array_ops.placeholder(dtype=np.int32, shape=[]) bij = ExpOnlyJacobian(forward_min_event_ndims=1) bij.inverse_log_det_jacobian(x, event_ndims=event_ndims) - with self.test_session() as sess: + with self.cached_session() as sess: ildj = sess.run(bij.inverse_log_det_jacobian(x, event_ndims=event_ndims), feed_dict={event_ndims: 1}) self.assertAllClose(-np.log(x_), ildj) diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py index 67ed0447ed..cace5b3ba2 100644 --- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py +++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py @@ -49,115 +49,102 @@ stats = try_import("scipy.stats") class DirichletTest(test.TestCase): def testSimpleShapes(self): - with self.test_session(): - alpha = np.random.rand(3) - dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual(3, self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) - self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) + alpha = np.random.rand(3) + dist = dirichlet_lib.Dirichlet(alpha) + self.assertEqual(3, self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([3]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape) def testComplexShapes(self): - with self.test_session(): - alpha = np.random.rand(3, 2, 2) - dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual(2, self.evaluate(dist.event_shape_tensor())) - self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor())) - self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) - self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) + alpha = np.random.rand(3, 2, 2) + dist = dirichlet_lib.Dirichlet(alpha) + self.assertEqual(2, self.evaluate(dist.event_shape_tensor())) + self.assertAllEqual([3, 2], self.evaluate(dist.batch_shape_tensor())) + self.assertEqual(tensor_shape.TensorShape([2]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2]), dist.batch_shape) def testConcentrationProperty(self): alpha = [[1., 2, 3]] - with self.test_session(): - dist = dirichlet_lib.Dirichlet(alpha) - self.assertEqual([1, 3], dist.concentration.get_shape()) - self.assertAllClose(alpha, self.evaluate(dist.concentration)) + dist = dirichlet_lib.Dirichlet(alpha) + self.assertEqual([1, 3], dist.concentration.get_shape()) + self.assertAllClose(alpha, self.evaluate(dist.concentration)) def testPdfXProper(self): alpha = [[1., 2, 3]] - with self.test_session(): - dist = dirichlet_lib.Dirichlet(alpha, validate_args=True) - self.evaluate(dist.prob([.1, .3, .6])) - self.evaluate(dist.prob([.2, .3, .5])) - # Either condition can trigger. - with self.assertRaisesOpError("samples must be positive"): - self.evaluate(dist.prob([-1., 1.5, 0.5])) - with self.assertRaisesOpError("samples must be positive"): - self.evaluate(dist.prob([0., .1, .9])) - with self.assertRaisesOpError( - "sample last-dimension must sum to `1`"): - self.evaluate(dist.prob([.1, .2, .8])) + dist = dirichlet_lib.Dirichlet(alpha, validate_args=True) + self.evaluate(dist.prob([.1, .3, .6])) + self.evaluate(dist.prob([.2, .3, .5])) + # Either condition can trigger. + with self.assertRaisesOpError("samples must be positive"): + self.evaluate(dist.prob([-1., 1.5, 0.5])) + with self.assertRaisesOpError("samples must be positive"): + self.evaluate(dist.prob([0., .1, .9])) + with self.assertRaisesOpError("sample last-dimension must sum to `1`"): + self.evaluate(dist.prob([.1, .2, .8])) def testPdfZeroBatches(self): - with self.test_session(): - alpha = [1., 2] - x = [.5, .5] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose(1., self.evaluate(pdf)) - self.assertEqual((), pdf.get_shape()) + alpha = [1., 2] + x = [.5, .5] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose(1., self.evaluate(pdf)) + self.assertEqual((), pdf.get_shape()) def testPdfZeroBatchesNontrivialX(self): - with self.test_session(): - alpha = [1., 2] - x = [.3, .7] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose(7. / 5, self.evaluate(pdf)) - self.assertEqual((), pdf.get_shape()) + alpha = [1., 2] + x = [.3, .7] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose(7. / 5, self.evaluate(pdf)) + self.assertEqual((), pdf.get_shape()) def testPdfUniformZeroBatches(self): - with self.test_session(): - # Corresponds to a uniform distribution - alpha = [1., 1, 1] - x = [[.2, .5, .3], [.3, .4, .3]] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose([2., 2.], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + # Corresponds to a uniform distribution + alpha = [1., 1, 1] + x = [[.2, .5, .3], [.3, .4, .3]] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose([2., 2.], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - alpha = [[1., 2]] - x = [[.5, .5], [.3, .7]] - dist = dirichlet_lib.Dirichlet(alpha) - pdf = dist.prob(x) - self.assertAllClose([1., 7. / 5], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = dirichlet_lib.Dirichlet(alpha) + pdf = dist.prob(x) + self.assertAllClose([1., 7. / 5], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfAlphaStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - alpha = [1., 2] - x = [[.5, .5], [.2, .8]] - pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 8. / 5], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = dirichlet_lib.Dirichlet(alpha).prob(x) + self.assertAllClose([1., 8. / 5], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenSameRank(self): - with self.test_session(): - alpha = [[1., 2], [2., 3]] - x = [[.5, .5]] - pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = dirichlet_lib.Dirichlet(alpha).prob(x) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testPdfXStretchedInBroadcastWhenLowerRank(self): - with self.test_session(): - alpha = [[1., 2], [2., 3]] - x = [.5, .5] - pdf = dirichlet_lib.Dirichlet(alpha).prob(x) - self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) - self.assertEqual((2), pdf.get_shape()) + alpha = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = dirichlet_lib.Dirichlet(alpha).prob(x) + self.assertAllClose([1., 3. / 2], self.evaluate(pdf)) + self.assertEqual((2), pdf.get_shape()) def testMean(self): - with self.test_session(): - alpha = [1., 2, 3] - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.mean().get_shape(), [3]) - if not stats: - return - expected_mean = stats.dirichlet.mean(alpha) - self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean) + alpha = [1., 2, 3] + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.mean().get_shape(), [3]) + if not stats: + return + expected_mean = stats.dirichlet.mean(alpha) + self.assertAllClose(self.evaluate(dirichlet.mean()), expected_mean) def testCovarianceFromSampling(self): alpha = np.array([[1., 2, 3], @@ -197,73 +184,66 @@ class DirichletTest(test.TestCase): self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.02, rtol=0.) def testVariance(self): - with self.test_session(): - alpha = [1., 2, 3] - denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.covariance().get_shape(), (3, 3)) - if not stats: - return - expected_covariance = np.diag(stats.dirichlet.var(alpha)) - expected_covariance += [[0., -2, -3], [-2, 0, -6], - [-3, -6, 0]] / denominator - self.assertAllClose( - self.evaluate(dirichlet.covariance()), expected_covariance) + alpha = [1., 2, 3] + denominator = np.sum(alpha)**2 * (np.sum(alpha) + 1) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.covariance().get_shape(), (3, 3)) + if not stats: + return + expected_covariance = np.diag(stats.dirichlet.var(alpha)) + expected_covariance += [[0., -2, -3], [-2, 0, -6], [-3, -6, 0] + ] / denominator + self.assertAllClose( + self.evaluate(dirichlet.covariance()), expected_covariance) def testMode(self): - with self.test_session(): - alpha = np.array([1.1, 2, 3]) - expected_mode = (alpha - 1) / (np.sum(alpha) - 3) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.mode().get_shape(), [3]) - self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) + alpha = np.array([1.1, 2, 3]) + expected_mode = (alpha - 1) / (np.sum(alpha) - 3) + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.mode().get_shape(), [3]) + self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) def testModeInvalid(self): - with self.test_session(): - alpha = np.array([1., 2, 3]) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, - allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): - self.evaluate(dirichlet.mode()) + alpha = np.array([1., 2, 3]) + dirichlet = dirichlet_lib.Dirichlet( + concentration=alpha, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + self.evaluate(dirichlet.mode()) def testModeEnableAllowNanStats(self): - with self.test_session(): - alpha = np.array([1., 2, 3]) - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha, - allow_nan_stats=True) - expected_mode = np.zeros_like(alpha) + np.nan + alpha = np.array([1., 2, 3]) + dirichlet = dirichlet_lib.Dirichlet( + concentration=alpha, allow_nan_stats=True) + expected_mode = np.zeros_like(alpha) + np.nan - self.assertEqual(dirichlet.mode().get_shape(), [3]) - self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) + self.assertEqual(dirichlet.mode().get_shape(), [3]) + self.assertAllClose(self.evaluate(dirichlet.mode()), expected_mode) def testEntropy(self): - with self.test_session(): - alpha = [1., 2, 3] - dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) - self.assertEqual(dirichlet.entropy().get_shape(), ()) - if not stats: - return - expected_entropy = stats.dirichlet.entropy(alpha) - self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy) + alpha = [1., 2, 3] + dirichlet = dirichlet_lib.Dirichlet(concentration=alpha) + self.assertEqual(dirichlet.entropy().get_shape(), ()) + if not stats: + return + expected_entropy = stats.dirichlet.entropy(alpha) + self.assertAllClose(self.evaluate(dirichlet.entropy()), expected_entropy) def testSample(self): - with self.test_session(): - alpha = [1., 2] - dirichlet = dirichlet_lib.Dirichlet(alpha) - n = constant_op.constant(100000) - samples = dirichlet.sample(n) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 2)) - self.assertTrue(np.all(sample_values > 0.0)) - if not stats: - return - self.assertLess( - stats.kstest( - # Beta is a univariate distribution. - sample_values[:, 0], - stats.beta( - a=1., b=2.).cdf)[0], - 0.01) + alpha = [1., 2] + dirichlet = dirichlet_lib.Dirichlet(alpha) + n = constant_op.constant(100000) + samples = dirichlet.sample(n) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 2)) + self.assertTrue(np.all(sample_values > 0.0)) + if not stats: + return + self.assertLess( + stats.kstest( + # Beta is a univariate distribution. + sample_values[:, 0], + stats.beta(a=1., b=2.).cdf)[0], + 0.01) def testDirichletFullyReparameterized(self): alpha = constant_op.constant([1.0, 2.0, 3.0]) diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py index 850da3e969..27d1291912 100644 --- a/tensorflow/python/kernel_tests/distributions/exponential_test.py +++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py @@ -22,7 +22,6 @@ import importlib import numpy as np -from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util @@ -48,121 +47,108 @@ stats = try_import("scipy.stats") class ExponentialTest(test.TestCase): def testExponentialLogPDF(self): - with session.Session(): - batch_size = 6 - lam = constant_op.constant([2.0] * batch_size) - lam_v = 2.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - exponential = exponential_lib.Exponential(rate=lam) + batch_size = 6 + lam = constant_op.constant([2.0] * batch_size) + lam_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + exponential = exponential_lib.Exponential(rate=lam) - log_pdf = exponential.log_prob(x) - self.assertEqual(log_pdf.get_shape(), (6,)) + log_pdf = exponential.log_prob(x) + self.assertEqual(log_pdf.get_shape(), (6,)) - pdf = exponential.prob(x) - self.assertEqual(pdf.get_shape(), (6,)) + pdf = exponential.prob(x) + self.assertEqual(pdf.get_shape(), (6,)) - if not stats: - return - expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) - self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) - self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) + if not stats: + return + expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testExponentialCDF(self): - with session.Session(): - batch_size = 6 - lam = constant_op.constant([2.0] * batch_size) - lam_v = 2.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + lam = constant_op.constant([2.0] * batch_size) + lam_v = 2.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - exponential = exponential_lib.Exponential(rate=lam) + exponential = exponential_lib.Exponential(rate=lam) - cdf = exponential.cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) + cdf = exponential.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + if not stats: + return + expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testExponentialMean(self): - with session.Session(): - lam_v = np.array([1.0, 4.0, 2.5]) - exponential = exponential_lib.Exponential(rate=lam_v) - self.assertEqual(exponential.mean().get_shape(), (3,)) - if not stats: - return - expected_mean = stats.expon.mean(scale=1 / lam_v) - self.assertAllClose(self.evaluate(exponential.mean()), expected_mean) + lam_v = np.array([1.0, 4.0, 2.5]) + exponential = exponential_lib.Exponential(rate=lam_v) + self.assertEqual(exponential.mean().get_shape(), (3,)) + if not stats: + return + expected_mean = stats.expon.mean(scale=1 / lam_v) + self.assertAllClose(self.evaluate(exponential.mean()), expected_mean) def testExponentialVariance(self): - with session.Session(): - lam_v = np.array([1.0, 4.0, 2.5]) - exponential = exponential_lib.Exponential(rate=lam_v) - self.assertEqual(exponential.variance().get_shape(), (3,)) - if not stats: - return - expected_variance = stats.expon.var(scale=1 / lam_v) - self.assertAllClose( - self.evaluate(exponential.variance()), expected_variance) + lam_v = np.array([1.0, 4.0, 2.5]) + exponential = exponential_lib.Exponential(rate=lam_v) + self.assertEqual(exponential.variance().get_shape(), (3,)) + if not stats: + return + expected_variance = stats.expon.var(scale=1 / lam_v) + self.assertAllClose( + self.evaluate(exponential.variance()), expected_variance) def testExponentialEntropy(self): - with session.Session(): - lam_v = np.array([1.0, 4.0, 2.5]) - exponential = exponential_lib.Exponential(rate=lam_v) - self.assertEqual(exponential.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.expon.entropy(scale=1 / lam_v) - self.assertAllClose( - self.evaluate(exponential.entropy()), expected_entropy) + lam_v = np.array([1.0, 4.0, 2.5]) + exponential = exponential_lib.Exponential(rate=lam_v) + self.assertEqual(exponential.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.expon.entropy(scale=1 / lam_v) + self.assertAllClose(self.evaluate(exponential.entropy()), expected_entropy) def testExponentialSample(self): - with self.test_session(): - lam = constant_op.constant([3.0, 4.0]) - lam_v = [3.0, 4.0] - n = constant_op.constant(100000) - exponential = exponential_lib.Exponential(rate=lam) - - samples = exponential.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 2)) - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - for i in range(2): - self.assertLess( - stats.kstest( - sample_values[:, i], stats.expon(scale=1.0 / lam_v[i]).cdf)[0], - 0.01) + lam = constant_op.constant([3.0, 4.0]) + lam_v = [3.0, 4.0] + n = constant_op.constant(100000) + exponential = exponential_lib.Exponential(rate=lam) + + samples = exponential.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + for i in range(2): + self.assertLess( + stats.kstest(sample_values[:, i], + stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) def testExponentialSampleMultiDimensional(self): - with self.test_session(): - batch_size = 2 - lam_v = [3.0, 22.0] - lam = constant_op.constant([lam_v] * batch_size) + batch_size = 2 + lam_v = [3.0, 22.0] + lam = constant_op.constant([lam_v] * batch_size) - exponential = exponential_lib.Exponential(rate=lam) + exponential = exponential_lib.Exponential(rate=lam) + + n = 100000 + samples = exponential.sample(n, seed=138) + self.assertEqual(samples.get_shape(), (n, batch_size, 2)) + + sample_values = self.evaluate(samples) - n = 100000 - samples = exponential.sample(n, seed=138) - self.assertEqual(samples.get_shape(), (n, batch_size, 2)) - - sample_values = self.evaluate(samples) - - self.assertFalse(np.any(sample_values < 0.0)) - if not stats: - return - for i in range(2): - self.assertLess( - stats.kstest( - sample_values[:, 0, i], - stats.expon(scale=1.0 / lam_v[i]).cdf)[0], - 0.01) - self.assertLess( - stats.kstest( - sample_values[:, 1, i], - stats.expon(scale=1.0 / lam_v[i]).cdf)[0], - 0.01) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + for i in range(2): + self.assertLess( + stats.kstest(sample_values[:, 0, i], + stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) + self.assertLess( + stats.kstest(sample_values[:, 1, i], + stats.expon(scale=1.0 / lam_v[i]).cdf)[0], 0.01) def testFullyReparameterized(self): lam = constant_op.constant([0.1, 1.0]) @@ -174,11 +160,10 @@ class ExponentialTest(test.TestCase): self.assertIsNotNone(grad_lam) def testExponentialWithSoftplusRate(self): - with self.test_session(): - lam = [-2.2, -3.4] - exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) - self.assertAllClose( - self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate)) + lam = [-2.2, -3.4] + exponential = exponential_lib.ExponentialWithSoftplusRate(rate=lam) + self.assertAllClose( + self.evaluate(nn_ops.softplus(lam)), self.evaluate(exponential.rate)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py index 297e20264c..4eff40b029 100644 --- a/tensorflow/python/kernel_tests/distributions/gamma_test.py +++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py @@ -50,221 +50,203 @@ stats = try_import("scipy.stats") class GammaTest(test.TestCase): def testGammaShape(self): - with self.test_session(): - alpha = constant_op.constant([3.0] * 5) - beta = constant_op.constant(11.0) - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + alpha = constant_op.constant([3.0] * 5) + beta = constant_op.constant(11.0) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,)) - self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), []) - self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) + self.assertEqual(self.evaluate(gamma.batch_shape_tensor()), (5,)) + self.assertEqual(gamma.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(gamma.event_shape_tensor()), []) + self.assertEqual(gamma.event_shape, tensor_shape.TensorShape([])) def testGammaLogPDF(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([2.0] * batch_size) - beta = constant_op.constant([3.0] * batch_size) - alpha_v = 2.0 - beta_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - log_pdf = gamma.log_prob(x) - self.assertEqual(log_pdf.get_shape(), (6,)) - pdf = gamma.prob(x) - self.assertEqual(pdf.get_shape(), (6,)) - if not stats: - return - expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) - self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) + batch_size = 6 + alpha = constant_op.constant([2.0] * batch_size) + beta = constant_op.constant([3.0] * batch_size) + alpha_v = 2.0 + beta_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + log_pdf = gamma.log_prob(x) + self.assertEqual(log_pdf.get_shape(), (6,)) + pdf = gamma.prob(x) + self.assertEqual(pdf.get_shape(), (6,)) + if not stats: + return + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testGammaLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([[2.0, 4.0]] * batch_size) - beta = constant_op.constant([[3.0, 4.0]] * batch_size) - alpha_v = np.array([2.0, 4.0]) - beta_v = np.array([3.0, 4.0]) - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - log_pdf = gamma.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - pdf = gamma.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - if not stats: - return - expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + alpha = constant_op.constant([[2.0, 4.0]] * batch_size) + beta = constant_op.constant([[3.0, 4.0]] * batch_size) + alpha_v = np.array([2.0, 4.0]) + beta_v = np.array([3.0, 4.0]) + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + log_pdf = gamma.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + pdf = gamma.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + if not stats: + return + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) def testGammaLogPDFMultidimensionalBroadcasting(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([[2.0, 4.0]] * batch_size) - beta = constant_op.constant(3.0) - alpha_v = np.array([2.0, 4.0]) - beta_v = 3.0 - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - log_pdf = gamma.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - pdf = gamma.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - - if not stats: - return - expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + alpha = constant_op.constant([[2.0, 4.0]] * batch_size) + beta = constant_op.constant(3.0) + alpha_v = np.array([2.0, 4.0]) + beta_v = 3.0 + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + log_pdf = gamma.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + pdf = gamma.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) - def testGammaCDF(self): - with self.test_session(): - batch_size = 6 - alpha = constant_op.constant([2.0] * batch_size) - beta = constant_op.constant([3.0] * batch_size) - alpha_v = 2.0 - beta_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + if not stats: + return + expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - cdf = gamma.cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + def testGammaCDF(self): + batch_size = 6 + alpha = constant_op.constant([2.0] * batch_size) + beta = constant_op.constant([3.0] * batch_size) + alpha_v = 2.0 + beta_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + cdf = gamma.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + if not stats: + return + expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testGammaMean(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.mean().get_shape(), (3,)) - if not stats: - return - expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(gamma.mean()), expected_means) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.mean().get_shape(), (3,)) + if not stats: + return + expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(gamma.mean()), expected_means) def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): - with self.test_session(): - alpha_v = np.array([5.5, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - expected_modes = (alpha_v - 1) / beta_v - self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) + alpha_v = np.array([5.5, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + expected_modes = (alpha_v - 1) / beta_v + self.assertEqual(gamma.mode().get_shape(), (3,)) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): - with self.test_session(): - # Mode will not be defined for the first entry. - alpha_v = np.array([0.5, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(gamma.mode()) + # Mode will not be defined for the first entry. + alpha_v = np.array([0.5, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + self.evaluate(gamma.mode()) def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self): - with self.test_session(): - # Mode will not be defined for the first entry. - alpha_v = np.array([0.5, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - allow_nan_stats=True) - expected_modes = (alpha_v - 1) / beta_v - expected_modes[0] = np.nan - self.assertEqual(gamma.mode().get_shape(), (3,)) - self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) + # Mode will not be defined for the first entry. + alpha_v = np.array([0.5, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, allow_nan_stats=True) + expected_modes = (alpha_v - 1) / beta_v + expected_modes[0] = np.nan + self.assertEqual(gamma.mode().get_shape(), (3,)) + self.assertAllClose(self.evaluate(gamma.mode()), expected_modes) def testGammaVariance(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.variance().get_shape(), (3,)) - if not stats: - return - expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.variance().get_shape(), (3,)) + if not stats: + return + expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(gamma.variance()), expected_variances) def testGammaStd(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.stddev().get_shape(), (3,)) - if not stats: - return - expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v) - self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.stddev().get_shape(), (3,)) + if not stats: + return + expected_stddev = stats.gamma.std(alpha_v, scale=1. / beta_v) + self.assertAllClose(self.evaluate(gamma.stddev()), expected_stddev) def testGammaEntropy(self): - with self.test_session(): - alpha_v = np.array([1.0, 3.0, 2.5]) - beta_v = np.array([1.0, 4.0, 5.0]) - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - self.assertEqual(gamma.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) - self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) + alpha_v = np.array([1.0, 3.0, 2.5]) + beta_v = np.array([1.0, 4.0, 5.0]) + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + self.assertEqual(gamma.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) + self.assertAllClose(self.evaluate(gamma.entropy()), expected_entropy) def testGammaSampleSmallAlpha(self): - with self.test_session(): - alpha_v = 0.05 - beta_v = 1.0 - alpha = constant_op.constant(alpha_v) - beta = constant_op.constant(beta_v) - n = 100000 - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - samples = gamma.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n,)) - self.assertEqual(sample_values.shape, (n,)) - self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) - if not stats: - return - self.assertAllClose( - sample_values.mean(), - stats.gamma.mean( - alpha_v, scale=1 / beta_v), - atol=.01) - self.assertAllClose( - sample_values.var(), - stats.gamma.var(alpha_v, scale=1 / beta_v), - atol=.15) + alpha_v = 0.05 + beta_v = 1.0 + alpha = constant_op.constant(alpha_v) + beta = constant_op.constant(beta_v) + n = 100000 + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + samples = gamma.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n,)) + self.assertEqual(sample_values.shape, (n,)) + self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) + if not stats: + return + self.assertAllClose( + sample_values.mean(), + stats.gamma.mean(alpha_v, scale=1 / beta_v), + atol=.01) + self.assertAllClose( + sample_values.var(), + stats.gamma.var(alpha_v, scale=1 / beta_v), + atol=.15) def testGammaSample(self): - with self.test_session(): - alpha_v = 4.0 - beta_v = 3.0 - alpha = constant_op.constant(alpha_v) - beta = constant_op.constant(beta_v) - n = 100000 - gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) - samples = gamma.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n,)) - self.assertEqual(sample_values.shape, (n,)) - self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) - if not stats: - return - self.assertAllClose( - sample_values.mean(), - stats.gamma.mean( - alpha_v, scale=1 / beta_v), - atol=.01) - self.assertAllClose( - sample_values.var(), - stats.gamma.var(alpha_v, scale=1 / beta_v), - atol=.15) + alpha_v = 4.0 + beta_v = 3.0 + alpha = constant_op.constant(alpha_v) + beta = constant_op.constant(beta_v) + n = 100000 + gamma = gamma_lib.Gamma(concentration=alpha, rate=beta) + samples = gamma.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n,)) + self.assertEqual(sample_values.shape, (n,)) + self.assertTrue(self._kstest(alpha_v, beta_v, sample_values)) + if not stats: + return + self.assertAllClose( + sample_values.mean(), + stats.gamma.mean(alpha_v, scale=1 / beta_v), + atol=.01) + self.assertAllClose( + sample_values.var(), + stats.gamma.var(alpha_v, scale=1 / beta_v), + atol=.15) def testGammaFullyReparameterized(self): alpha = constant_op.constant(4.0) @@ -279,37 +261,37 @@ class GammaTest(test.TestCase): self.assertIsNotNone(grad_beta) def testGammaSampleMultiDimensional(self): - with self.test_session(): - alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 - beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 - gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) - n = 10000 - samples = gamma.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n, 10, 100)) - self.assertEqual(sample_values.shape, (n, 10, 100)) - zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100 - alpha_bc = alpha_v + zeros - beta_bc = beta_v + zeros - if not stats: - return - self.assertAllClose( - sample_values.mean(axis=0), - stats.gamma.mean( - alpha_bc, scale=1 / beta_bc), - atol=0., rtol=.05) - self.assertAllClose( - sample_values.var(axis=0), - stats.gamma.var(alpha_bc, scale=1 / beta_bc), - atol=10.0, rtol=0.) - fails = 0 - trials = 0 - for ai, a in enumerate(np.reshape(alpha_v, [-1])): - for bi, b in enumerate(np.reshape(beta_v, [-1])): - s = sample_values[:, bi, ai] - trials += 1 - fails += 0 if self._kstest(a, b, s) else 1 - self.assertLess(fails, trials * 0.03) + alpha_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 + beta_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 + gamma = gamma_lib.Gamma(concentration=alpha_v, rate=beta_v) + n = 10000 + samples = gamma.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n, 10, 100)) + self.assertEqual(sample_values.shape, (n, 10, 100)) + zeros = np.zeros_like(alpha_v + beta_v) # 10 x 100 + alpha_bc = alpha_v + zeros + beta_bc = beta_v + zeros + if not stats: + return + self.assertAllClose( + sample_values.mean(axis=0), + stats.gamma.mean(alpha_bc, scale=1 / beta_bc), + atol=0., + rtol=.05) + self.assertAllClose( + sample_values.var(axis=0), + stats.gamma.var(alpha_bc, scale=1 / beta_bc), + atol=10.0, + rtol=0.) + fails = 0 + trials = 0 + for ai, a in enumerate(np.reshape(alpha_v, [-1])): + for bi, b in enumerate(np.reshape(beta_v, [-1])): + s = sample_values[:, bi, ai] + trials += 1 + fails += 0 if self._kstest(a, b, s) else 1 + self.assertLess(fails, trials * 0.03) def _kstest(self, alpha, beta, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. @@ -320,30 +302,29 @@ class GammaTest(test.TestCase): return ks < 0.02 def testGammaPdfOfSampleMultiDims(self): - with self.test_session(): - gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) - num = 50000 - samples = gamma.sample(num, seed=137) - pdfs = gamma.prob(samples) - sample_vals, pdf_vals = self.evaluate([samples, pdfs]) - self.assertEqual(samples.get_shape(), (num, 2, 2)) - self.assertEqual(pdfs.get_shape(), (num, 2, 2)) - self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) - self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) - self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) - self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) - if not stats: - return - self.assertAllClose( - stats.gamma.mean( - [[7., 11.], [7., 11.]], scale=1 / np.array([[5., 5.], [6., 6.]])), - sample_vals.mean(axis=0), - atol=.1) - self.assertAllClose( - stats.gamma.var([[7., 11.], [7., 11.]], - scale=1 / np.array([[5., 5.], [6., 6.]])), - sample_vals.var(axis=0), - atol=.1) + gamma = gamma_lib.Gamma(concentration=[7., 11.], rate=[[5.], [6.]]) + num = 50000 + samples = gamma.sample(num, seed=137) + pdfs = gamma.prob(samples) + sample_vals, pdf_vals = self.evaluate([samples, pdfs]) + self.assertEqual(samples.get_shape(), (num, 2, 2)) + self.assertEqual(pdfs.get_shape(), (num, 2, 2)) + self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02) + self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02) + self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02) + self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02) + if not stats: + return + self.assertAllClose( + stats.gamma.mean([[7., 11.], [7., 11.]], + scale=1 / np.array([[5., 5.], [6., 6.]])), + sample_vals.mean(axis=0), + atol=.1) + self.assertAllClose( + stats.gamma.var([[7., 11.], [7., 11.]], + scale=1 / np.array([[5., 5.], [6., 6.]])), + sample_vals.var(axis=0), + atol=.1) def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3): s_p = zip(sample_vals, pdf_vals) @@ -356,32 +337,29 @@ class GammaTest(test.TestCase): self.assertNear(1., total, err=err) def testGammaNonPositiveInitializationParamsRaises(self): - with self.test_session(): - alpha_v = constant_op.constant(0.0, name="alpha") - beta_v = constant_op.constant(1.0, name="beta") - with self.assertRaisesOpError("x > 0"): - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - self.evaluate(gamma.mean()) - alpha_v = constant_op.constant(1.0, name="alpha") - beta_v = constant_op.constant(0.0, name="beta") - with self.assertRaisesOpError("x > 0"): - gamma = gamma_lib.Gamma(concentration=alpha_v, - rate=beta_v, - validate_args=True) - self.evaluate(gamma.mean()) + alpha_v = constant_op.constant(0.0, name="alpha") + beta_v = constant_op.constant(1.0, name="beta") + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, validate_args=True) + self.evaluate(gamma.mean()) + alpha_v = constant_op.constant(1.0, name="alpha") + beta_v = constant_op.constant(0.0, name="beta") + with self.assertRaisesOpError("x > 0"): + gamma = gamma_lib.Gamma( + concentration=alpha_v, rate=beta_v, validate_args=True) + self.evaluate(gamma.mean()) def testGammaWithSoftplusConcentrationRate(self): - with self.test_session(): - alpha_v = constant_op.constant([0.0, -2.1], name="alpha") - beta_v = constant_op.constant([1.0, -3.6], name="beta") - gamma = gamma_lib.GammaWithSoftplusConcentrationRate( - concentration=alpha_v, rate=beta_v) - self.assertAllEqual(self.evaluate(nn_ops.softplus(alpha_v)), - self.evaluate(gamma.concentration)) - self.assertAllEqual(self.evaluate(nn_ops.softplus(beta_v)), - self.evaluate(gamma.rate)) + alpha_v = constant_op.constant([0.0, -2.1], name="alpha") + beta_v = constant_op.constant([1.0, -3.6], name="beta") + gamma = gamma_lib.GammaWithSoftplusConcentrationRate( + concentration=alpha_v, rate=beta_v) + self.assertAllEqual( + self.evaluate(nn_ops.softplus(alpha_v)), + self.evaluate(gamma.concentration)) + self.assertAllEqual( + self.evaluate(nn_ops.softplus(beta_v)), self.evaluate(gamma.rate)) def testGammaGammaKL(self): alpha0 = np.array([3.]) @@ -391,15 +369,14 @@ class GammaTest(test.TestCase): beta1 = np.array([0.5, 1., 1.5, 2., 2.5, 3.]) # Build graph. - with self.test_session(): - g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) - g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) - x = g0.sample(int(1e4), seed=0) - kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) - kl_actual = kullback_leibler.kl_divergence(g0, g1) - - # Execute graph. - [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual]) + g0 = gamma_lib.Gamma(concentration=alpha0, rate=beta0) + g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1) + x = g0.sample(int(1e4), seed=0) + kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0) + kl_actual = kullback_leibler.kl_divergence(g0, g1) + + # Execute graph. + [kl_sample_, kl_actual_] = self.evaluate([kl_sample, kl_actual]) self.assertEqual(beta0.shape, kl_actual.get_shape()) diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py index 24b243f647..630c2cb424 100644 --- a/tensorflow/python/kernel_tests/distributions/laplace_test.py +++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py @@ -21,7 +21,6 @@ import importlib import numpy as np -from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import tensor_shape @@ -49,212 +48,198 @@ stats = try_import("scipy.stats") class LaplaceTest(test.TestCase): def testLaplaceShape(self): - with self.test_session(): - loc = constant_op.constant([3.0] * 5) - scale = constant_op.constant(11.0) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + loc = constant_op.constant([3.0] * 5) + scale = constant_op.constant(11.0) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,)) - self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), []) - self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([])) + self.assertEqual(self.evaluate(laplace.batch_shape_tensor()), (5,)) + self.assertEqual(laplace.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(laplace.event_shape_tensor()), []) + self.assertEqual(laplace.event_shape, tensor_shape.TensorShape([])) def testLaplaceLogPDF(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - log_pdf = laplace.log_prob(x) - self.assertEqual(log_pdf.get_shape(), (6,)) - if not stats: - return - expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + log_pdf = laplace.log_prob(x) + self.assertEqual(log_pdf.get_shape(), (6,)) + if not stats: + return + expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf) - pdf = laplace.prob(x) - self.assertEqual(pdf.get_shape(), (6,)) - self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) + pdf = laplace.prob(x) + self.assertEqual(pdf.get_shape(), (6,)) + self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf)) def testLaplaceLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([[2.0, 4.0]] * batch_size) - scale = constant_op.constant([[3.0, 4.0]] * batch_size) - loc_v = np.array([2.0, 4.0]) - scale_v = np.array([3.0, 4.0]) - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - log_pdf = laplace.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - - pdf = laplace.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - if not stats: - return - expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + loc = constant_op.constant([[2.0, 4.0]] * batch_size) + scale = constant_op.constant([[3.0, 4.0]] * batch_size) + loc_v = np.array([2.0, 4.0]) + scale_v = np.array([3.0, 4.0]) + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + log_pdf = laplace.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + + pdf = laplace.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + if not stats: + return + expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) + self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) def testLaplaceLogPDFMultidimensionalBroadcasting(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([[2.0, 4.0]] * batch_size) - scale = constant_op.constant(3.0) - loc_v = np.array([2.0, 4.0]) - scale_v = 3.0 - x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - log_pdf = laplace.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - - pdf = laplace.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - if not stats: - return - expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) - self.assertAllClose(log_pdf_values, expected_log_pdf) - self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) + batch_size = 6 + loc = constant_op.constant([[2.0, 4.0]] * batch_size) + scale = constant_op.constant(3.0) + loc_v = np.array([2.0, 4.0]) + scale_v = 3.0 + x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + log_pdf = laplace.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + + pdf = laplace.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + if not stats: + return + expected_log_pdf = stats.laplace.logpdf(x, loc_v, scale=scale_v) + self.assertAllClose(log_pdf_values, expected_log_pdf) + self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) def testLaplaceCDF(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - cdf = laplace.cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + cdf = laplace.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + if not stats: + return + expected_cdf = stats.laplace.cdf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testLaplaceLogCDF(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - cdf = laplace.log_cdf(x) - self.assertEqual(cdf.get_shape(), (6,)) - if not stats: - return - expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(cdf), expected_cdf) + cdf = laplace.log_cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + if not stats: + return + expected_cdf = stats.laplace.logcdf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(cdf), expected_cdf) def testLaplaceLogSurvivalFunction(self): - with self.test_session(): - batch_size = 6 - loc = constant_op.constant([2.0] * batch_size) - scale = constant_op.constant([3.0] * batch_size) - loc_v = 2.0 - scale_v = 3.0 - x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) + batch_size = 6 + loc = constant_op.constant([2.0] * batch_size) + scale = constant_op.constant([3.0] * batch_size) + loc_v = 2.0 + scale_v = 3.0 + x = np.array([-2.5, 2.5, -4.0, 0.1, 1.0, 2.0], dtype=np.float32) - laplace = laplace_lib.Laplace(loc=loc, scale=scale) + laplace = laplace_lib.Laplace(loc=loc, scale=scale) - sf = laplace.log_survival_function(x) - self.assertEqual(sf.get_shape(), (6,)) - if not stats: - return - expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(sf), expected_sf) + sf = laplace.log_survival_function(x) + self.assertEqual(sf.get_shape(), (6,)) + if not stats: + return + expected_sf = stats.laplace.logsf(x, loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(sf), expected_sf) def testLaplaceMean(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.mean().get_shape(), (3,)) - if not stats: - return - expected_means = stats.laplace.mean(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.mean()), expected_means) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.mean().get_shape(), (3,)) + if not stats: + return + expected_means = stats.laplace.mean(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.mean()), expected_means) def testLaplaceMode(self): - with self.test_session(): - loc_v = np.array([0.5, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.mode().get_shape(), (3,)) - self.assertAllClose(self.evaluate(laplace.mode()), loc_v) + loc_v = np.array([0.5, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.mode().get_shape(), (3,)) + self.assertAllClose(self.evaluate(laplace.mode()), loc_v) def testLaplaceVariance(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.variance().get_shape(), (3,)) - if not stats: - return - expected_variances = stats.laplace.var(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.variance()), expected_variances) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.variance().get_shape(), (3,)) + if not stats: + return + expected_variances = stats.laplace.var(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.variance()), expected_variances) def testLaplaceStd(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.stddev().get_shape(), (3,)) - if not stats: - return - expected_stddev = stats.laplace.std(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.stddev().get_shape(), (3,)) + if not stats: + return + expected_stddev = stats.laplace.std(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.stddev()), expected_stddev) def testLaplaceEntropy(self): - with self.test_session(): - loc_v = np.array([1.0, 3.0, 2.5]) - scale_v = np.array([1.0, 4.0, 5.0]) - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - self.assertEqual(laplace.entropy().get_shape(), (3,)) - if not stats: - return - expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v) - self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy) + loc_v = np.array([1.0, 3.0, 2.5]) + scale_v = np.array([1.0, 4.0, 5.0]) + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + self.assertEqual(laplace.entropy().get_shape(), (3,)) + if not stats: + return + expected_entropy = stats.laplace.entropy(loc_v, scale=scale_v) + self.assertAllClose(self.evaluate(laplace.entropy()), expected_entropy) def testLaplaceSample(self): - with session.Session(): - loc_v = 4.0 - scale_v = 3.0 - loc = constant_op.constant(loc_v) - scale = constant_op.constant(scale_v) - n = 100000 - laplace = laplace_lib.Laplace(loc=loc, scale=scale) - samples = laplace.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n,)) - self.assertEqual(sample_values.shape, (n,)) - if not stats: - return - self.assertAllClose( - sample_values.mean(), - stats.laplace.mean( - loc_v, scale=scale_v), - rtol=0.05, - atol=0.) - self.assertAllClose( - sample_values.var(), - stats.laplace.var(loc_v, scale=scale_v), - rtol=0.05, - atol=0.) - self.assertTrue(self._kstest(loc_v, scale_v, sample_values)) + loc_v = 4.0 + scale_v = 3.0 + loc = constant_op.constant(loc_v) + scale = constant_op.constant(scale_v) + n = 100000 + laplace = laplace_lib.Laplace(loc=loc, scale=scale) + samples = laplace.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n,)) + self.assertEqual(sample_values.shape, (n,)) + if not stats: + return + self.assertAllClose( + sample_values.mean(), + stats.laplace.mean(loc_v, scale=scale_v), + rtol=0.05, + atol=0.) + self.assertAllClose( + sample_values.var(), + stats.laplace.var(loc_v, scale=scale_v), + rtol=0.05, + atol=0.) + self.assertTrue(self._kstest(loc_v, scale_v, sample_values)) def testLaplaceFullyReparameterized(self): loc = constant_op.constant(4.0) @@ -269,39 +254,37 @@ class LaplaceTest(test.TestCase): self.assertIsNotNone(grad_scale) def testLaplaceSampleMultiDimensional(self): - with session.Session(): - loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 - scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 - laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) - n = 10000 - samples = laplace.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (n, 10, 100)) - self.assertEqual(sample_values.shape, (n, 10, 100)) - zeros = np.zeros_like(loc_v + scale_v) # 10 x 100 - loc_bc = loc_v + zeros - scale_bc = scale_v + zeros - if not stats: - return - self.assertAllClose( - sample_values.mean(axis=0), - stats.laplace.mean( - loc_bc, scale=scale_bc), - rtol=0.35, - atol=0.) - self.assertAllClose( - sample_values.var(axis=0), - stats.laplace.var(loc_bc, scale=scale_bc), - rtol=0.10, - atol=0.) - fails = 0 - trials = 0 - for ai, a in enumerate(np.reshape(loc_v, [-1])): - for bi, b in enumerate(np.reshape(scale_v, [-1])): - s = sample_values[:, bi, ai] - trials += 1 - fails += 0 if self._kstest(a, b, s) else 1 - self.assertLess(fails, trials * 0.03) + loc_v = np.array([np.arange(1, 101, dtype=np.float32)]) # 1 x 100 + scale_v = np.array([np.arange(1, 11, dtype=np.float32)]).T # 10 x 1 + laplace = laplace_lib.Laplace(loc=loc_v, scale=scale_v) + n = 10000 + samples = laplace.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (n, 10, 100)) + self.assertEqual(sample_values.shape, (n, 10, 100)) + zeros = np.zeros_like(loc_v + scale_v) # 10 x 100 + loc_bc = loc_v + zeros + scale_bc = scale_v + zeros + if not stats: + return + self.assertAllClose( + sample_values.mean(axis=0), + stats.laplace.mean(loc_bc, scale=scale_bc), + rtol=0.35, + atol=0.) + self.assertAllClose( + sample_values.var(axis=0), + stats.laplace.var(loc_bc, scale=scale_bc), + rtol=0.10, + atol=0.) + fails = 0 + trials = 0 + for ai, a in enumerate(np.reshape(loc_v, [-1])): + for bi, b in enumerate(np.reshape(scale_v, [-1])): + s = sample_values[:, bi, ai] + trials += 1 + fails += 0 if self._kstest(a, b, s) else 1 + self.assertLess(fails, trials * 0.03) def _kstest(self, loc, scale, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. @@ -349,30 +332,26 @@ class LaplaceTest(test.TestCase): self.assertNear(1., total, err=err) def testLaplaceNonPositiveInitializationParamsRaises(self): - with self.test_session(): - loc_v = constant_op.constant(0.0, name="loc") - scale_v = constant_op.constant(-1.0, name="scale") - with self.assertRaisesOpError( - "Condition x > 0 did not hold element-wise"): - laplace = laplace_lib.Laplace( - loc=loc_v, scale=scale_v, validate_args=True) - self.evaluate(laplace.mean()) - loc_v = constant_op.constant(1.0, name="loc") - scale_v = constant_op.constant(0.0, name="scale") - with self.assertRaisesOpError( - "Condition x > 0 did not hold element-wise"): - laplace = laplace_lib.Laplace( - loc=loc_v, scale=scale_v, validate_args=True) - self.evaluate(laplace.mean()) + loc_v = constant_op.constant(0.0, name="loc") + scale_v = constant_op.constant(-1.0, name="scale") + with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"): + laplace = laplace_lib.Laplace( + loc=loc_v, scale=scale_v, validate_args=True) + self.evaluate(laplace.mean()) + loc_v = constant_op.constant(1.0, name="loc") + scale_v = constant_op.constant(0.0, name="scale") + with self.assertRaisesOpError("Condition x > 0 did not hold element-wise"): + laplace = laplace_lib.Laplace( + loc=loc_v, scale=scale_v, validate_args=True) + self.evaluate(laplace.mean()) def testLaplaceWithSoftplusScale(self): - with self.test_session(): - loc_v = constant_op.constant([0.0, 1.0], name="loc") - scale_v = constant_op.constant([-1.0, 2.0], name="scale") - laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v) - self.assertAllClose( - self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale)) - self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc)) + loc_v = constant_op.constant([0.0, 1.0], name="loc") + scale_v = constant_op.constant([-1.0, 2.0], name="scale") + laplace = laplace_lib.LaplaceWithSoftplusScale(loc=loc_v, scale=scale_v) + self.assertAllClose( + self.evaluate(nn_ops.softplus(scale_v)), self.evaluate(laplace.scale)) + self.assertAllClose(self.evaluate(loc_v), self.evaluate(laplace.loc)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py index 5dcd6f6df4..de73a40b23 100644 --- a/tensorflow/python/kernel_tests/distributions/normal_test.py +++ b/tensorflow/python/kernel_tests/distributions/normal_test.py @@ -61,16 +61,15 @@ class NormalTest(test.TestCase): self.assertAllEqual(all_true, is_finite) def _testParamShapes(self, sample_shape, expected): - with self.test_session(): - param_shapes = normal_lib.Normal.param_shapes(sample_shape) - mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] - self.assertAllEqual(expected, self.evaluate(mu_shape)) - self.assertAllEqual(expected, self.evaluate(sigma_shape)) - mu = array_ops.zeros(mu_shape) - sigma = array_ops.ones(sigma_shape) - self.assertAllEqual( - expected, - self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample()))) + param_shapes = normal_lib.Normal.param_shapes(sample_shape) + mu_shape, sigma_shape = param_shapes["loc"], param_shapes["scale"] + self.assertAllEqual(expected, self.evaluate(mu_shape)) + self.assertAllEqual(expected, self.evaluate(sigma_shape)) + mu = array_ops.zeros(mu_shape) + sigma = array_ops.ones(sigma_shape) + self.assertAllEqual( + expected, + self.evaluate(array_ops.shape(normal_lib.Normal(mu, sigma).sample()))) def _testParamStaticShapes(self, sample_shape, expected): param_shapes = normal_lib.Normal.param_static_shapes(sample_shape) @@ -93,154 +92,148 @@ class NormalTest(test.TestCase): @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNormalWithSoftplusScale(self): - with self.test_session(): - mu = array_ops.zeros((10, 3)) - rho = array_ops.ones((10, 3)) * -2. - normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) - self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc)) - self.assertAllEqual( - self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) + mu = array_ops.zeros((10, 3)) + rho = array_ops.ones((10, 3)) * -2. + normal = normal_lib.NormalWithSoftplusScale(loc=mu, scale=rho) + self.assertAllEqual(self.evaluate(mu), self.evaluate(normal.loc)) + self.assertAllEqual( + self.evaluate(nn_ops.softplus(rho)), self.evaluate(normal.scale)) @test_util.run_in_graph_and_eager_modes def testNormalLogPDF(self): - with self.test_session(): - batch_size = 6 - mu = constant_op.constant([3.0] * batch_size) - sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) - x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) - normal = normal_lib.Normal(loc=mu, scale=sigma) - - log_pdf = normal.log_prob(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(log_pdf).shape) - self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) + batch_size = 6 + mu = constant_op.constant([3.0] * batch_size) + sigma = constant_op.constant([math.sqrt(10.0)] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + normal = normal_lib.Normal(loc=mu, scale=sigma) - pdf = normal.prob(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(pdf).shape) - self.assertAllEqual(normal.batch_shape, pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape) - - if not stats: - return - expected_log_pdf = stats.norm(self.evaluate(mu), - self.evaluate(sigma)).logpdf(x) - self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf)) - self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) + log_pdf = normal.log_prob(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(log_pdf).shape) + self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) + + pdf = normal.prob(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(pdf).shape) + self.assertAllEqual(normal.batch_shape, pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(pdf).shape) + + if not stats: + return + expected_log_pdf = stats.norm(self.evaluate(mu), + self.evaluate(sigma)).logpdf(x) + self.assertAllClose(expected_log_pdf, self.evaluate(log_pdf)) + self.assertAllClose(np.exp(expected_log_pdf), self.evaluate(pdf)) @test_util.run_in_graph_and_eager_modes def testNormalLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - mu = constant_op.constant([[3.0, -3.0]] * batch_size) - sigma = constant_op.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * - batch_size) - x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T - normal = normal_lib.Normal(loc=mu, scale=sigma) - - log_pdf = normal.log_prob(x) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(log_pdf).shape) - self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) - - pdf = normal.prob(x) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), pdf_values.shape) - self.assertAllEqual(normal.batch_shape, pdf.get_shape()) - self.assertAllEqual(normal.batch_shape, pdf_values.shape) + batch_size = 6 + mu = constant_op.constant([[3.0, -3.0]] * batch_size) + sigma = constant_op.constant( + [[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + normal = normal_lib.Normal(loc=mu, scale=sigma) - if not stats: - return - expected_log_pdf = stats.norm(self.evaluate(mu), - self.evaluate(sigma)).logpdf(x) - self.assertAllClose(expected_log_pdf, log_pdf_values) - self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + log_pdf = normal.log_prob(x) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), log_pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(log_pdf).shape) + self.assertAllEqual(normal.batch_shape, log_pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(log_pdf).shape) + + pdf = normal.prob(x) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), pdf_values.shape) + self.assertAllEqual(normal.batch_shape, pdf.get_shape()) + self.assertAllEqual(normal.batch_shape, pdf_values.shape) + + if not stats: + return + expected_log_pdf = stats.norm(self.evaluate(mu), + self.evaluate(sigma)).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) @test_util.run_in_graph_and_eager_modes def testNormalCDF(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) - cdf = normal.cdf(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(cdf).shape) - self.assertAllEqual(normal.batch_shape, cdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) - if not stats: - return - expected_cdf = stats.norm(mu, sigma).cdf(x) - self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) + normal = normal_lib.Normal(loc=mu, scale=sigma) + cdf = normal.cdf(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(cdf).shape) + self.assertAllEqual(normal.batch_shape, cdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) + if not stats: + return + expected_cdf = stats.norm(mu, sigma).cdf(x) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0) @test_util.run_in_graph_and_eager_modes def testNormalSurvivalFunction(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - sf = normal.survival_function(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(sf).shape) - self.assertAllEqual(normal.batch_shape, sf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) - if not stats: - return - expected_sf = stats.norm(mu, sigma).sf(x) - self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) + sf = normal.survival_function(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(sf).shape) + self.assertAllEqual(normal.batch_shape, sf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) + if not stats: + return + expected_sf = stats.norm(mu, sigma).sf(x) + self.assertAllClose(expected_sf, self.evaluate(sf), atol=0) @test_util.run_in_graph_and_eager_modes def testNormalLogCDF(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - cdf = normal.log_cdf(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(cdf).shape) - self.assertAllEqual(normal.batch_shape, cdf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) + cdf = normal.log_cdf(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), cdf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(cdf).shape) + self.assertAllEqual(normal.batch_shape, cdf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(cdf).shape) - if not stats: - return - expected_cdf = stats.norm(mu, sigma).logcdf(x) - self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) + if not stats: + return + expected_cdf = stats.norm(mu, sigma).logcdf(x) + self.assertAllClose(expected_cdf, self.evaluate(cdf), atol=0, rtol=1e-3) def testFiniteGradientAtDifficultPoints(self): for dtype in [np.float32, np.float64]: @@ -256,7 +249,7 @@ class NormalTest(test.TestCase): ]: value = func(x) grads = gradients_impl.gradients(value, [mu, sigma]) - with self.test_session(graph=g): + with self.session(graph=g): variables.global_variables_initializer().run() self.assertAllFinite(value) self.assertAllFinite(grads[0]) @@ -264,112 +257,106 @@ class NormalTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNormalLogSurvivalFunction(self): - with self.test_session(): - batch_size = 50 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) + batch_size = 50 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - sf = normal.log_survival_function(x) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(sf).shape) - self.assertAllEqual(normal.batch_shape, sf.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) + sf = normal.log_survival_function(x) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), sf.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(sf).shape) + self.assertAllEqual(normal.batch_shape, sf.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(sf).shape) - if not stats: - return - expected_sf = stats.norm(mu, sigma).logsf(x) - self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) + if not stats: + return + expected_sf = stats.norm(mu, sigma).logsf(x) + self.assertAllClose(expected_sf, self.evaluate(sf), atol=0, rtol=1e-5) @test_util.run_in_graph_and_eager_modes def testNormalEntropyWithScalarInputs(self): # Scipy.stats.norm cannot deal with the shapes in the other test. - with self.test_session(): - mu_v = 2.34 - sigma_v = 4.56 - normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) - - entropy = normal.entropy() - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(entropy).shape) - self.assertAllEqual(normal.batch_shape, entropy.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) - # scipy.stats.norm cannot deal with these shapes. - if not stats: - return - expected_entropy = stats.norm(mu_v, sigma_v).entropy() - self.assertAllClose(expected_entropy, self.evaluate(entropy)) + mu_v = 2.34 + sigma_v = 4.56 + normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) + + entropy = normal.entropy() + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(entropy).shape) + self.assertAllEqual(normal.batch_shape, entropy.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) + # scipy.stats.norm cannot deal with these shapes. + if not stats: + return + expected_entropy = stats.norm(mu_v, sigma_v).entropy() + self.assertAllClose(expected_entropy, self.evaluate(entropy)) @test_util.run_in_graph_and_eager_modes def testNormalEntropy(self): - with self.test_session(): - mu_v = np.array([1.0, 1.0, 1.0]) - sigma_v = np.array([[1.0, 2.0, 3.0]]).T - normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) - - # scipy.stats.norm cannot deal with these shapes. - sigma_broadcast = mu_v * sigma_v - expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast** - 2) - entropy = normal.entropy() - np.testing.assert_allclose(expected_entropy, self.evaluate(entropy)) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(entropy).shape) - self.assertAllEqual(normal.batch_shape, entropy.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) + mu_v = np.array([1.0, 1.0, 1.0]) + sigma_v = np.array([[1.0, 2.0, 3.0]]).T + normal = normal_lib.Normal(loc=mu_v, scale=sigma_v) + + # scipy.stats.norm cannot deal with these shapes. + sigma_broadcast = mu_v * sigma_v + expected_entropy = 0.5 * np.log(2 * np.pi * np.exp(1) * sigma_broadcast**2) + entropy = normal.entropy() + np.testing.assert_allclose(expected_entropy, self.evaluate(entropy)) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), entropy.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(entropy).shape) + self.assertAllEqual(normal.batch_shape, entropy.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(entropy).shape) @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNormalMeanAndMode(self): - with self.test_session(): - # Mu will be broadcast to [7, 7, 7]. - mu = [7.] - sigma = [11., 12., 13.] + # Mu will be broadcast to [7, 7, 7]. + mu = [7.] + sigma = [11., 12., 13.] - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertAllEqual((3,), normal.mean().get_shape()) - self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean())) + self.assertAllEqual((3,), normal.mean().get_shape()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.mean())) - self.assertAllEqual((3,), normal.mode().get_shape()) - self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) + self.assertAllEqual((3,), normal.mode().get_shape()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.mode())) @test_util.run_in_graph_and_eager_modes def testNormalQuantile(self): - with self.test_session(): - batch_size = 52 - mu = self._rng.randn(batch_size) - sigma = self._rng.rand(batch_size) + 1.0 - p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64) - # Quantile performs piecewise rational approximation so adding some - # special input values to make sure we hit all the pieces. - p = np.hstack((p, np.exp(-33), 1. - np.exp(-33))) + batch_size = 52 + mu = self._rng.randn(batch_size) + sigma = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0., 1.0, batch_size - 2).astype(np.float64) + # Quantile performs piecewise rational approximation so adding some + # special input values to make sure we hit all the pieces. + p = np.hstack((p, np.exp(-33), 1. - np.exp(-33))) - normal = normal_lib.Normal(loc=mu, scale=sigma) - x = normal.quantile(p) + normal = normal_lib.Normal(loc=mu, scale=sigma) + x = normal.quantile(p) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), x.get_shape()) - self.assertAllEqual( - self.evaluate(normal.batch_shape_tensor()), - self.evaluate(x).shape) - self.assertAllEqual(normal.batch_shape, x.get_shape()) - self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), x.get_shape()) + self.assertAllEqual( + self.evaluate(normal.batch_shape_tensor()), + self.evaluate(x).shape) + self.assertAllEqual(normal.batch_shape, x.get_shape()) + self.assertAllEqual(normal.batch_shape, self.evaluate(x).shape) - if not stats: - return - expected_x = stats.norm(mu, sigma).ppf(p) - self.assertAllClose(expected_x, self.evaluate(x), atol=0.) + if not stats: + return + expected_x = stats.norm(mu, sigma).ppf(p) + self.assertAllClose(expected_x, self.evaluate(x), atol=0.) def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype): g = ops.Graph() @@ -385,7 +372,7 @@ class NormalTest(test.TestCase): value = dist.quantile(p) grads = gradients_impl.gradients(value, [mu, p]) - with self.test_session(graph=g): + with self.cached_session(graph=g): variables.global_variables_initializer().run() self.assertAllFinite(grads[0]) self.assertAllFinite(grads[1]) @@ -398,61 +385,58 @@ class NormalTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNormalVariance(self): - with self.test_session(): - # sigma will be broadcast to [7, 7, 7] - mu = [1., 2., 3.] - sigma = [7.] + # sigma will be broadcast to [7, 7, 7] + mu = [1., 2., 3.] + sigma = [7.] - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertAllEqual((3,), normal.variance().get_shape()) - self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) + self.assertAllEqual((3,), normal.variance().get_shape()) + self.assertAllEqual([49., 49, 49], self.evaluate(normal.variance())) @test_util.run_in_graph_and_eager_modes def testNormalStandardDeviation(self): - with self.test_session(): - # sigma will be broadcast to [7, 7, 7] - mu = [1., 2., 3.] - sigma = [7.] + # sigma will be broadcast to [7, 7, 7] + mu = [1., 2., 3.] + sigma = [7.] - normal = normal_lib.Normal(loc=mu, scale=sigma) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertAllEqual((3,), normal.stddev().get_shape()) - self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) + self.assertAllEqual((3,), normal.stddev().get_shape()) + self.assertAllEqual([7., 7, 7], self.evaluate(normal.stddev())) @test_util.run_in_graph_and_eager_modes def testNormalSample(self): - with self.test_session(): - mu = constant_op.constant(3.0) - sigma = constant_op.constant(math.sqrt(3.0)) - mu_v = 3.0 - sigma_v = np.sqrt(3.0) - n = constant_op.constant(100000) - normal = normal_lib.Normal(loc=mu, scale=sigma) - samples = normal.sample(n) - sample_values = self.evaluate(samples) - # Note that the standard error for the sample mean is ~ sigma / sqrt(n). - # The sample variance similarly is dependent on sigma and n. - # Thus, the tolerances below are very sensitive to number of samples - # as well as the variances chosen. - self.assertEqual(sample_values.shape, (100000,)) - self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1) - self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) - - expected_samples_shape = tensor_shape.TensorShape( - [self.evaluate(n)]).concatenate( - tensor_shape.TensorShape( - self.evaluate(normal.batch_shape_tensor()))) - - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) - - expected_samples_shape = ( - tensor_shape.TensorShape([self.evaluate(n)]).concatenate( - normal.batch_shape)) - - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) + mu = constant_op.constant(3.0) + sigma = constant_op.constant(math.sqrt(3.0)) + mu_v = 3.0 + sigma_v = np.sqrt(3.0) + n = constant_op.constant(100000) + normal = normal_lib.Normal(loc=mu, scale=sigma) + samples = normal.sample(n) + sample_values = self.evaluate(samples) + # Note that the standard error for the sample mean is ~ sigma / sqrt(n). + # The sample variance similarly is dependent on sigma and n. + # Thus, the tolerances below are very sensitive to number of samples + # as well as the variances chosen. + self.assertEqual(sample_values.shape, (100000,)) + self.assertAllClose(sample_values.mean(), mu_v, atol=1e-1) + self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1) + + expected_samples_shape = tensor_shape.TensorShape( + [self.evaluate(n)]).concatenate( + tensor_shape.TensorShape( + self.evaluate(normal.batch_shape_tensor()))) + + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) + + expected_samples_shape = ( + tensor_shape.TensorShape([self.evaluate(n)]).concatenate( + normal.batch_shape)) + + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) def testNormalFullyReparameterized(self): mu = constant_op.constant(4.0) @@ -468,66 +452,63 @@ class NormalTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNormalSampleMultiDimensional(self): - with self.test_session(): - batch_size = 2 - mu = constant_op.constant([[3.0, -3.0]] * batch_size) - sigma = constant_op.constant([[math.sqrt(2.0), math.sqrt(3.0)]] * - batch_size) - mu_v = [3.0, -3.0] - sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] - n = constant_op.constant(100000) - normal = normal_lib.Normal(loc=mu, scale=sigma) - samples = normal.sample(n) - sample_values = self.evaluate(samples) - # Note that the standard error for the sample mean is ~ sigma / sqrt(n). - # The sample variance similarly is dependent on sigma and n. - # Thus, the tolerances below are very sensitive to number of samples - # as well as the variances chosen. - self.assertEqual(samples.get_shape(), (100000, batch_size, 2)) - self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1) - self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1) - self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1) - self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1) - - expected_samples_shape = tensor_shape.TensorShape( - [self.evaluate(n)]).concatenate( - tensor_shape.TensorShape( - self.evaluate(normal.batch_shape_tensor()))) - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) - - expected_samples_shape = ( - tensor_shape.TensorShape([self.evaluate(n)]).concatenate( - normal.batch_shape)) - self.assertAllEqual(expected_samples_shape, samples.get_shape()) - self.assertAllEqual(expected_samples_shape, sample_values.shape) + batch_size = 2 + mu = constant_op.constant([[3.0, -3.0]] * batch_size) + sigma = constant_op.constant( + [[math.sqrt(2.0), math.sqrt(3.0)]] * batch_size) + mu_v = [3.0, -3.0] + sigma_v = [np.sqrt(2.0), np.sqrt(3.0)] + n = constant_op.constant(100000) + normal = normal_lib.Normal(loc=mu, scale=sigma) + samples = normal.sample(n) + sample_values = self.evaluate(samples) + # Note that the standard error for the sample mean is ~ sigma / sqrt(n). + # The sample variance similarly is dependent on sigma and n. + # Thus, the tolerances below are very sensitive to number of samples + # as well as the variances chosen. + self.assertEqual(samples.get_shape(), (100000, batch_size, 2)) + self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-1) + self.assertAllClose(sample_values[:, 0, 0].std(), sigma_v[0], atol=1e-1) + self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=1e-1) + self.assertAllClose(sample_values[:, 0, 1].std(), sigma_v[1], atol=1e-1) + + expected_samples_shape = tensor_shape.TensorShape( + [self.evaluate(n)]).concatenate( + tensor_shape.TensorShape( + self.evaluate(normal.batch_shape_tensor()))) + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) + + expected_samples_shape = ( + tensor_shape.TensorShape([self.evaluate(n)]).concatenate( + normal.batch_shape)) + self.assertAllEqual(expected_samples_shape, samples.get_shape()) + self.assertAllEqual(expected_samples_shape, sample_values.shape) @test_util.run_in_graph_and_eager_modes def testNegativeSigmaFails(self): - with self.test_session(): - with self.assertRaisesOpError("Condition x > 0 did not hold"): - normal = normal_lib.Normal( - loc=[1.], scale=[-5.], validate_args=True, name="G") - self.evaluate(normal.mean()) + with self.assertRaisesOpError("Condition x > 0 did not hold"): + normal = normal_lib.Normal( + loc=[1.], scale=[-5.], validate_args=True, name="G") + self.evaluate(normal.mean()) @test_util.run_in_graph_and_eager_modes def testNormalShape(self): - with self.test_session(): - mu = constant_op.constant([-3.0] * 5) - sigma = constant_op.constant(11.0) - normal = normal_lib.Normal(loc=mu, scale=sigma) + mu = constant_op.constant([-3.0] * 5) + sigma = constant_op.constant(11.0) + normal = normal_lib.Normal(loc=mu, scale=sigma) - self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5]) - self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) - self.assertEqual(normal.event_shape, tensor_shape.TensorShape([])) + self.assertEqual(self.evaluate(normal.batch_shape_tensor()), [5]) + self.assertEqual(normal.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(normal.event_shape_tensor()), []) + self.assertEqual(normal.event_shape, tensor_shape.TensorShape([])) def testNormalShapeWithPlaceholders(self): mu = array_ops.placeholder(dtype=dtypes.float32) sigma = array_ops.placeholder(dtype=dtypes.float32) normal = normal_lib.Normal(loc=mu, scale=sigma) - with self.test_session() as sess: + with self.cached_session() as sess: # get_batch_shape should return an "" tensor. self.assertEqual(normal.batch_shape, tensor_shape.TensorShape(None)) self.assertEqual(normal.event_shape, ()) diff --git a/tensorflow/python/kernel_tests/distributions/special_math_test.py b/tensorflow/python/kernel_tests/distributions/special_math_test.py index a634194ce5..cc43e12168 100644 --- a/tensorflow/python/kernel_tests/distributions/special_math_test.py +++ b/tensorflow/python/kernel_tests/distributions/special_math_test.py @@ -92,22 +92,21 @@ class NdtriTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testNdtri(self): """Verifies that ndtri computation is correct.""" - with self.test_session(): - if not special: - return + if not special: + return - p = np.linspace(0., 1.0, 50).astype(np.float64) - # Quantile performs piecewise rational approximation so adding some - # special input values to make sure we hit all the pieces. - p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), - np.exp(-2), 1. - np.exp(-2))) - expected_x = special.ndtri(p) - x = special_math.ndtri(p) - self.assertAllClose(expected_x, self.evaluate(x), atol=0.) + p = np.linspace(0., 1.0, 50).astype(np.float64) + # Quantile performs piecewise rational approximation so adding some + # special input values to make sure we hit all the pieces. + p = np.hstack((p, np.exp(-32), 1. - np.exp(-32), np.exp(-2), + 1. - np.exp(-2))) + expected_x = special.ndtri(p) + x = special_math.ndtri(p) + self.assertAllClose(expected_x, self.evaluate(x), atol=0.) def testNdtriDynamicShape(self): """Verifies that ndtri computation is correct.""" - with self.test_session() as sess: + with self.cached_session() as sess: if not special: return @@ -286,7 +285,7 @@ class NdtrGradientTest(test.TestCase): def _test_grad_accuracy(self, dtype, grid_spec, error_spec): raw_grid = _make_grid(dtype, grid_spec) grid = ops.convert_to_tensor(raw_grid) - with self.test_session(): + with self.cached_session(): fn = sm.log_ndtr if self._use_log else sm.ndtr # If there are N points in the grid, @@ -355,7 +354,7 @@ class LogNdtrGradientTest(NdtrGradientTest): class ErfInvTest(test.TestCase): def testErfInvValues(self): - with self.test_session(): + with self.cached_session(): if not special: return @@ -366,7 +365,7 @@ class ErfInvTest(test.TestCase): self.assertAllClose(expected_x, x.eval(), atol=0.) def testErfInvIntegerInput(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): x = np.array([1, 2, 3]).astype(np.int32) @@ -397,7 +396,7 @@ class LogCDFLaplaceTest(test.TestCase): self.assertAllEqual(np.ones_like(x, dtype=np.bool), x) def _test_grid_log(self, dtype, scipy_dtype, grid_spec, error_spec): - with self.test_session(): + with self.cached_session(): grid = _make_grid(dtype, grid_spec) actual = sm.log_cdf_laplace(grid).eval() @@ -439,7 +438,7 @@ class LogCDFLaplaceTest(test.TestCase): ErrorSpec(rtol=0.05, atol=0)) def test_float32_extreme_values_result_and_gradient_finite_and_nonzero(self): - with self.test_session() as sess: + with self.cached_session() as sess: # On the lower branch, log_cdf_laplace(x) = x, so we know this will be # fine, but test to -200 anyways. grid = _make_grid( @@ -458,7 +457,7 @@ class LogCDFLaplaceTest(test.TestCase): self.assertFalse(np.any(grad_ == 0)) def test_float64_extreme_values_result_and_gradient_finite_and_nonzero(self): - with self.test_session() as sess: + with self.cached_session() as sess: # On the lower branch, log_cdf_laplace(x) = x, so we know this will be # fine, but test to -200 anyways. grid = _make_grid( diff --git a/tensorflow/python/kernel_tests/distributions/student_t_test.py b/tensorflow/python/kernel_tests/distributions/student_t_test.py index 05590542ef..b34b538160 100644 --- a/tensorflow/python/kernel_tests/distributions/student_t_test.py +++ b/tensorflow/python/kernel_tests/distributions/student_t_test.py @@ -50,100 +50,96 @@ stats = try_import("scipy.stats") class StudentTTest(test.TestCase): def testStudentPDFAndLogPDF(self): - with self.test_session(): - batch_size = 6 - df = constant_op.constant([3.] * batch_size) - mu = constant_op.constant([7.] * batch_size) - sigma = constant_op.constant([8.] * batch_size) - df_v = 3. - mu_v = 7. - sigma_v = 8. - t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) - student = student_t.StudentT(df, loc=mu, scale=-sigma) - - log_pdf = student.log_prob(t) - self.assertEquals(log_pdf.get_shape(), (6,)) - log_pdf_values = self.evaluate(log_pdf) - pdf = student.prob(t) - self.assertEquals(pdf.get_shape(), (6,)) - pdf_values = self.evaluate(pdf) - - if not stats: - return - - expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) - expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) - self.assertAllClose(expected_log_pdf, log_pdf_values) - self.assertAllClose(np.log(expected_pdf), log_pdf_values) - self.assertAllClose(expected_pdf, pdf_values) - self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + batch_size = 6 + df = constant_op.constant([3.] * batch_size) + mu = constant_op.constant([7.] * batch_size) + sigma = constant_op.constant([8.] * batch_size) + df_v = 3. + mu_v = 7. + sigma_v = 8. + t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) + student = student_t.StudentT(df, loc=mu, scale=-sigma) + + log_pdf = student.log_prob(t) + self.assertEquals(log_pdf.get_shape(), (6,)) + log_pdf_values = self.evaluate(log_pdf) + pdf = student.prob(t) + self.assertEquals(pdf.get_shape(), (6,)) + pdf_values = self.evaluate(pdf) + + if not stats: + return + + expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) + expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.log(expected_pdf), log_pdf_values) + self.assertAllClose(expected_pdf, pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) def testStudentLogPDFMultidimensional(self): - with self.test_session(): - batch_size = 6 - df = constant_op.constant([[1.5, 7.2]] * batch_size) - mu = constant_op.constant([[3., -3.]] * batch_size) - sigma = constant_op.constant([[-math.sqrt(10.), math.sqrt(15.)]] * - batch_size) - df_v = np.array([1.5, 7.2]) - mu_v = np.array([3., -3.]) - sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)]) - t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T - student = student_t.StudentT(df, loc=mu, scale=sigma) - log_pdf = student.log_prob(t) - log_pdf_values = self.evaluate(log_pdf) - self.assertEqual(log_pdf.get_shape(), (6, 2)) - pdf = student.prob(t) - pdf_values = self.evaluate(pdf) - self.assertEqual(pdf.get_shape(), (6, 2)) - - if not stats: - return - expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) - expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) - self.assertAllClose(expected_log_pdf, log_pdf_values) - self.assertAllClose(np.log(expected_pdf), log_pdf_values) - self.assertAllClose(expected_pdf, pdf_values) - self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + batch_size = 6 + df = constant_op.constant([[1.5, 7.2]] * batch_size) + mu = constant_op.constant([[3., -3.]] * batch_size) + sigma = constant_op.constant( + [[-math.sqrt(10.), math.sqrt(15.)]] * batch_size) + df_v = np.array([1.5, 7.2]) + mu_v = np.array([3., -3.]) + sigma_v = np.array([np.sqrt(10.), np.sqrt(15.)]) + t = np.array([[-2.5, 2.5, 4., 0., -1., 2.]], dtype=np.float32).T + student = student_t.StudentT(df, loc=mu, scale=sigma) + log_pdf = student.log_prob(t) + log_pdf_values = self.evaluate(log_pdf) + self.assertEqual(log_pdf.get_shape(), (6, 2)) + pdf = student.prob(t) + pdf_values = self.evaluate(pdf) + self.assertEqual(pdf.get_shape(), (6, 2)) + + if not stats: + return + expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v) + expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.log(expected_pdf), log_pdf_values) + self.assertAllClose(expected_pdf, pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) def testStudentCDFAndLogCDF(self): - with self.test_session(): - batch_size = 6 - df = constant_op.constant([3.] * batch_size) - mu = constant_op.constant([7.] * batch_size) - sigma = constant_op.constant([-8.] * batch_size) - df_v = 3. - mu_v = 7. - sigma_v = 8. - t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) - student = student_t.StudentT(df, loc=mu, scale=sigma) - - log_cdf = student.log_cdf(t) - self.assertEquals(log_cdf.get_shape(), (6,)) - log_cdf_values = self.evaluate(log_cdf) - cdf = student.cdf(t) - self.assertEquals(cdf.get_shape(), (6,)) - cdf_values = self.evaluate(cdf) - - if not stats: - return - expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v) - expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v) - self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5) - self.assertAllClose( - np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5) - self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5) - self.assertAllClose( - np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5) + batch_size = 6 + df = constant_op.constant([3.] * batch_size) + mu = constant_op.constant([7.] * batch_size) + sigma = constant_op.constant([-8.] * batch_size) + df_v = 3. + mu_v = 7. + sigma_v = 8. + t = np.array([-2.5, 2.5, 8., 0., -1., 2.], dtype=np.float32) + student = student_t.StudentT(df, loc=mu, scale=sigma) + + log_cdf = student.log_cdf(t) + self.assertEquals(log_cdf.get_shape(), (6,)) + log_cdf_values = self.evaluate(log_cdf) + cdf = student.cdf(t) + self.assertEquals(cdf.get_shape(), (6,)) + cdf_values = self.evaluate(cdf) + + if not stats: + return + expected_log_cdf = stats.t.logcdf(t, df_v, loc=mu_v, scale=sigma_v) + expected_cdf = stats.t.cdf(t, df_v, loc=mu_v, scale=sigma_v) + self.assertAllClose(expected_log_cdf, log_cdf_values, atol=0., rtol=1e-5) + self.assertAllClose( + np.log(expected_cdf), log_cdf_values, atol=0., rtol=1e-5) + self.assertAllClose(expected_cdf, cdf_values, atol=0., rtol=1e-5) + self.assertAllClose( + np.exp(expected_log_cdf), cdf_values, atol=0., rtol=1e-5) def testStudentEntropy(self): df_v = np.array([[2., 3., 7.]]) # 1x3 mu_v = np.array([[1., -1, 0]]) # 1x3 sigma_v = np.array([[1., -2., 3.]]).T # transposed => 3x1 - with self.test_session(): - student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v) - ent = student.entropy() - ent_values = self.evaluate(ent) + student = student_t.StudentT(df=df_v, loc=mu_v, scale=sigma_v) + ent = student.entropy() + ent_values = self.evaluate(ent) # Help scipy broadcast to 3x3 ones = np.array([[1, 1, 1]]) @@ -160,90 +156,81 @@ class StudentTTest(test.TestCase): self.assertAllClose(expected_entropy, ent_values) def testStudentSample(self): - with self.test_session(): - df = constant_op.constant(4.) - mu = constant_op.constant(3.) - sigma = constant_op.constant(-math.sqrt(10.)) - df_v = 4. - mu_v = 3. - sigma_v = np.sqrt(10.) - n = constant_op.constant(200000) - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - samples = student.sample(n, seed=123456) - sample_values = self.evaluate(samples) - n_val = 200000 - self.assertEqual(sample_values.shape, (n_val,)) - self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) - self.assertAllClose( - sample_values.var(), - sigma_v**2 * df_v / (df_v - 2), - rtol=0.1, - atol=0) - self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) + df = constant_op.constant(4.) + mu = constant_op.constant(3.) + sigma = constant_op.constant(-math.sqrt(10.)) + df_v = 4. + mu_v = 3. + sigma_v = np.sqrt(10.) + n = constant_op.constant(200000) + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + samples = student.sample(n, seed=123456) + sample_values = self.evaluate(samples) + n_val = 200000 + self.assertEqual(sample_values.shape, (n_val,)) + self.assertAllClose(sample_values.mean(), mu_v, rtol=0.1, atol=0) + self.assertAllClose( + sample_values.var(), sigma_v**2 * df_v / (df_v - 2), rtol=0.1, atol=0) + self._checkKLApprox(df_v, mu_v, sigma_v, sample_values) # Test that sampling with the same seed twice gives the same results. def testStudentSampleMultipleTimes(self): - with self.test_session(): - df = constant_op.constant(4.) - mu = constant_op.constant(3.) - sigma = constant_op.constant(math.sqrt(10.)) - n = constant_op.constant(100) + df = constant_op.constant(4.) + mu = constant_op.constant(3.) + sigma = constant_op.constant(math.sqrt(10.)) + n = constant_op.constant(100) - random_seed.set_random_seed(654321) - student = student_t.StudentT( - df=df, loc=mu, scale=sigma, name="student_t1") - samples1 = self.evaluate(student.sample(n, seed=123456)) + random_seed.set_random_seed(654321) + student = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t1") + samples1 = self.evaluate(student.sample(n, seed=123456)) - random_seed.set_random_seed(654321) - student2 = student_t.StudentT( - df=df, loc=mu, scale=sigma, name="student_t2") - samples2 = self.evaluate(student2.sample(n, seed=123456)) + random_seed.set_random_seed(654321) + student2 = student_t.StudentT(df=df, loc=mu, scale=sigma, name="student_t2") + samples2 = self.evaluate(student2.sample(n, seed=123456)) - self.assertAllClose(samples1, samples2) + self.assertAllClose(samples1, samples2) def testStudentSampleSmallDfNoNan(self): - with self.test_session(): - df_v = [1e-1, 1e-5, 1e-10, 1e-20] - df = constant_op.constant(df_v) - n = constant_op.constant(200000) - student = student_t.StudentT(df=df, loc=1., scale=1.) - samples = student.sample(n, seed=123456) - sample_values = self.evaluate(samples) - n_val = 200000 - self.assertEqual(sample_values.shape, (n_val, 4)) - self.assertTrue(np.all(np.logical_not(np.isnan(sample_values)))) + df_v = [1e-1, 1e-5, 1e-10, 1e-20] + df = constant_op.constant(df_v) + n = constant_op.constant(200000) + student = student_t.StudentT(df=df, loc=1., scale=1.) + samples = student.sample(n, seed=123456) + sample_values = self.evaluate(samples) + n_val = 200000 + self.assertEqual(sample_values.shape, (n_val, 4)) + self.assertTrue(np.all(np.logical_not(np.isnan(sample_values)))) def testStudentSampleMultiDimensional(self): - with self.test_session(): - batch_size = 7 - df = constant_op.constant([[5., 7.]] * batch_size) - mu = constant_op.constant([[3., -3.]] * batch_size) - sigma = constant_op.constant([[math.sqrt(10.), math.sqrt(15.)]] * - batch_size) - df_v = [5., 7.] - mu_v = [3., -3.] - sigma_v = [np.sqrt(10.), np.sqrt(15.)] - n = constant_op.constant(200000) - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - samples = student.sample(n, seed=123456) - sample_values = self.evaluate(samples) - self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) - self.assertAllClose( - sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0) - self.assertAllClose( - sample_values[:, 0, 0].var(), - sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), - rtol=0.2, - atol=0) - self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) - self.assertAllClose( - sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0) - self.assertAllClose( - sample_values[:, 0, 1].var(), - sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), - rtol=0.2, - atol=0) - self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1]) + batch_size = 7 + df = constant_op.constant([[5., 7.]] * batch_size) + mu = constant_op.constant([[3., -3.]] * batch_size) + sigma = constant_op.constant( + [[math.sqrt(10.), math.sqrt(15.)]] * batch_size) + df_v = [5., 7.] + mu_v = [3., -3.] + sigma_v = [np.sqrt(10.), np.sqrt(15.)] + n = constant_op.constant(200000) + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + samples = student.sample(n, seed=123456) + sample_values = self.evaluate(samples) + self.assertEqual(samples.get_shape(), (200000, batch_size, 2)) + self.assertAllClose( + sample_values[:, 0, 0].mean(), mu_v[0], rtol=0.1, atol=0) + self.assertAllClose( + sample_values[:, 0, 0].var(), + sigma_v[0]**2 * df_v[0] / (df_v[0] - 2), + rtol=0.2, + atol=0) + self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0]) + self.assertAllClose( + sample_values[:, 0, 1].mean(), mu_v[1], rtol=0.1, atol=0) + self.assertAllClose( + sample_values[:, 0, 1].var(), + sigma_v[1]**2 * df_v[1] / (df_v[1] - 2), + rtol=0.2, + atol=0) + self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1]) def _checkKLApprox(self, df, mu, sigma, samples): n = samples.size @@ -325,114 +312,102 @@ class StudentTTest(test.TestCase): _check2d_rows(student_t.StudentT(df=7., loc=3., scale=[[2.], [3.], [4.]])) def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): - with self.test_session(): - mu = [1., 3.3, 4.4] - student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) - mean = self.evaluate(student.mean()) - self.assertAllClose([1., 3.3, 4.4], mean) + mu = [1., 3.3, 4.4] + student = student_t.StudentT(df=[3., 5., 7.], loc=mu, scale=[3., 2., 1.]) + mean = self.evaluate(student.mean()) + self.assertAllClose([1., 3.3, 4.4], mean) def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): - with self.test_session(): - mu = [1., 3.3, 4.4] - student = student_t.StudentT( - df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], - allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(student.mean()) + mu = [1., 3.3, 4.4] + student = student_t.StudentT( + df=[0.5, 5., 7.], loc=mu, scale=[3., 2., 1.], allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + self.evaluate(student.mean()) def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self): - with self.test_session(): - mu = [-2, 0., 1., 3.3, 4.4] - sigma = [5., 4., 3., 2., 1.] - student = student_t.StudentT( - df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, - allow_nan_stats=True) - mean = self.evaluate(student.mean()) - self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) + mu = [-2, 0., 1., 3.3, 4.4] + sigma = [5., 4., 3., 2., 1.] + student = student_t.StudentT( + df=[0.5, 1., 3., 5., 7.], loc=mu, scale=sigma, allow_nan_stats=True) + mean = self.evaluate(student.mean()) + self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self): - with self.test_session(): - # df = 0.5 ==> undefined mean ==> undefined variance. - # df = 1.5 ==> infinite variance. - df = [0.5, 1.5, 3., 5., 7.] - mu = [-2, 0., 1., 3.3, 4.4] - sigma = [5., 4., 3., 2., 1.] - student = student_t.StudentT( - df=df, loc=mu, scale=sigma, allow_nan_stats=True) - var = self.evaluate(student.variance()) - ## scipy uses inf for variance when the mean is undefined. When mean is - # undefined we say variance is undefined as well. So test the first - # member of var, making sure it is NaN, then replace with inf and compare - # to scipy. - self.assertTrue(np.isnan(var[0])) - var[0] = np.inf - - if not stats: - return - expected_var = [ - stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) - ] - self.assertAllClose(expected_var, var) + # df = 0.5 ==> undefined mean ==> undefined variance. + # df = 1.5 ==> infinite variance. + df = [0.5, 1.5, 3., 5., 7.] + mu = [-2, 0., 1., 3.3, 4.4] + sigma = [5., 4., 3., 2., 1.] + student = student_t.StudentT( + df=df, loc=mu, scale=sigma, allow_nan_stats=True) + var = self.evaluate(student.variance()) + ## scipy uses inf for variance when the mean is undefined. When mean is + # undefined we say variance is undefined as well. So test the first + # member of var, making sure it is NaN, then replace with inf and compare + # to scipy. + self.assertTrue(np.isnan(var[0])) + var[0] = np.inf + + if not stats: + return + expected_var = [ + stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) + ] + self.assertAllClose(expected_var, var) def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers( self): - with self.test_session(): - # df = 1.5 ==> infinite variance. - df = [1.5, 3., 5., 7.] - mu = [0., 1., 3.3, 4.4] - sigma = [4., 3., 2., 1.] - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - var = self.evaluate(student.variance()) + # df = 1.5 ==> infinite variance. + df = [1.5, 3., 5., 7.] + mu = [0., 1., 3.3, 4.4] + sigma = [4., 3., 2., 1.] + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + var = self.evaluate(student.variance()) - if not stats: - return - expected_var = [ - stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) - ] - self.assertAllClose(expected_var, var) + if not stats: + return + expected_var = [ + stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) + ] + self.assertAllClose(expected_var, var) def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): - with self.test_session(): - # df <= 1 ==> variance not defined - student = student_t.StudentT( - df=1., loc=0., scale=1., allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(student.variance()) + # df <= 1 ==> variance not defined + student = student_t.StudentT(df=1., loc=0., scale=1., allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + self.evaluate(student.variance()) - with self.test_session(): - # df <= 1 ==> variance not defined - student = student_t.StudentT( - df=0.5, loc=0., scale=1., allow_nan_stats=False) - with self.assertRaisesOpError("x < y"): - self.evaluate(student.variance()) + # df <= 1 ==> variance not defined + student = student_t.StudentT( + df=0.5, loc=0., scale=1., allow_nan_stats=False) + with self.assertRaisesOpError("x < y"): + self.evaluate(student.variance()) def testStd(self): - with self.test_session(): - # Defined for all batch members. - df = [3.5, 5., 3., 5., 7.] - mu = [-2.2] - sigma = [5., 4., 3., 2., 1.] - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - # Test broadcast of mu across shape of df/sigma - stddev = self.evaluate(student.stddev()) - mu *= len(df) + # Defined for all batch members. + df = [3.5, 5., 3., 5., 7.] + mu = [-2.2] + sigma = [5., 4., 3., 2., 1.] + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + # Test broadcast of mu across shape of df/sigma + stddev = self.evaluate(student.stddev()) + mu *= len(df) - if not stats: - return - expected_stddev = [ - stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) - ] - self.assertAllClose(expected_stddev, stddev) + if not stats: + return + expected_stddev = [ + stats.t.std(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma) + ] + self.assertAllClose(expected_stddev, stddev) def testMode(self): - with self.test_session(): - df = [0.5, 1., 3] - mu = [-1, 0., 1] - sigma = [5., 4., 3.] - student = student_t.StudentT(df=df, loc=mu, scale=sigma) - # Test broadcast of mu across shape of df/sigma - mode = self.evaluate(student.mode()) - self.assertAllClose([-1., 0, 1], mode) + df = [0.5, 1., 3] + mu = [-1, 0., 1] + sigma = [5., 4., 3.] + student = student_t.StudentT(df=df, loc=mu, scale=sigma) + # Test broadcast of mu across shape of df/sigma + mode = self.evaluate(student.mode()) + self.assertAllClose([-1., 0, 1], mode) def testPdfOfSample(self): student = student_t.StudentT(df=3., loc=np.pi, scale=1.) @@ -510,25 +485,23 @@ class StudentTTest(test.TestCase): self.assertNear(1., total, err=err) def testNegativeDofFails(self): - with self.test_session(): - with self.assertRaisesOpError(r"Condition x > 0 did not hold"): - student = student_t.StudentT( - df=[2, -5.], loc=0., scale=1., validate_args=True, name="S") - self.evaluate(student.mean()) + with self.assertRaisesOpError(r"Condition x > 0 did not hold"): + student = student_t.StudentT( + df=[2, -5.], loc=0., scale=1., validate_args=True, name="S") + self.evaluate(student.mean()) def testStudentTWithAbsDfSoftplusScale(self): - with self.test_session(): - df = constant_op.constant([-3.2, -4.6]) - mu = constant_op.constant([-4.2, 3.4]) - sigma = constant_op.constant([-6.4, -8.8]) - student = student_t.StudentTWithAbsDfSoftplusScale( - df=df, loc=mu, scale=sigma) - self.assertAllClose( - math_ops.floor(self.evaluate(math_ops.abs(df))), - self.evaluate(student.df)) - self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc)) - self.assertAllClose( - self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale)) + df = constant_op.constant([-3.2, -4.6]) + mu = constant_op.constant([-4.2, 3.4]) + sigma = constant_op.constant([-6.4, -8.8]) + student = student_t.StudentTWithAbsDfSoftplusScale( + df=df, loc=mu, scale=sigma) + self.assertAllClose( + math_ops.floor(self.evaluate(math_ops.abs(df))), + self.evaluate(student.df)) + self.assertAllClose(self.evaluate(mu), self.evaluate(student.loc)) + self.assertAllClose( + self.evaluate(nn_ops.softplus(sigma)), self.evaluate(student.scale)) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/distributions/uniform_test.py b/tensorflow/python/kernel_tests/distributions/uniform_test.py index bc9c267b9a..9cdcd369c1 100644 --- a/tensorflow/python/kernel_tests/distributions/uniform_test.py +++ b/tensorflow/python/kernel_tests/distributions/uniform_test.py @@ -50,255 +50,239 @@ class UniformTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testUniformRange(self): - with self.test_session(): - a = 3.0 - b = 10.0 - uniform = uniform_lib.Uniform(low=a, high=b) - self.assertAllClose(a, self.evaluate(uniform.low)) - self.assertAllClose(b, self.evaluate(uniform.high)) - self.assertAllClose(b - a, self.evaluate(uniform.range())) + a = 3.0 + b = 10.0 + uniform = uniform_lib.Uniform(low=a, high=b) + self.assertAllClose(a, self.evaluate(uniform.low)) + self.assertAllClose(b, self.evaluate(uniform.high)) + self.assertAllClose(b - a, self.evaluate(uniform.range())) @test_util.run_in_graph_and_eager_modes def testUniformPDF(self): - with self.test_session(): - a = constant_op.constant([-3.0] * 5 + [15.0]) - b = constant_op.constant([11.0] * 5 + [20.0]) - uniform = uniform_lib.Uniform(low=a, high=b) + a = constant_op.constant([-3.0] * 5 + [15.0]) + b = constant_op.constant([11.0] * 5 + [20.0]) + uniform = uniform_lib.Uniform(low=a, high=b) - a_v = -3.0 - b_v = 11.0 - x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32) + a_v = -3.0 + b_v = 11.0 + x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32) - def _expected_pdf(): - pdf = np.zeros_like(x) + 1.0 / (b_v - a_v) - pdf[x > b_v] = 0.0 - pdf[x < a_v] = 0.0 - pdf[5] = 1.0 / (20.0 - 15.0) - return pdf + def _expected_pdf(): + pdf = np.zeros_like(x) + 1.0 / (b_v - a_v) + pdf[x > b_v] = 0.0 + pdf[x < a_v] = 0.0 + pdf[5] = 1.0 / (20.0 - 15.0) + return pdf - expected_pdf = _expected_pdf() + expected_pdf = _expected_pdf() - pdf = uniform.prob(x) - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + pdf = uniform.prob(x) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) - log_pdf = uniform.log_prob(x) - self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf)) + log_pdf = uniform.log_prob(x) + self.assertAllClose(np.log(expected_pdf), self.evaluate(log_pdf)) @test_util.run_in_graph_and_eager_modes def testUniformShape(self): - with self.test_session(): - a = constant_op.constant([-3.0] * 5) - b = constant_op.constant(11.0) - uniform = uniform_lib.Uniform(low=a, high=b) + a = constant_op.constant([-3.0] * 5) + b = constant_op.constant(11.0) + uniform = uniform_lib.Uniform(low=a, high=b) - self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,)) - self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5])) - self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), []) - self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([])) + self.assertEqual(self.evaluate(uniform.batch_shape_tensor()), (5,)) + self.assertEqual(uniform.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(self.evaluate(uniform.event_shape_tensor()), []) + self.assertEqual(uniform.event_shape, tensor_shape.TensorShape([])) @test_util.run_in_graph_and_eager_modes def testUniformPDFWithScalarEndpoint(self): - with self.test_session(): - a = constant_op.constant([0.0, 5.0]) - b = constant_op.constant(10.0) - uniform = uniform_lib.Uniform(low=a, high=b) + a = constant_op.constant([0.0, 5.0]) + b = constant_op.constant(10.0) + uniform = uniform_lib.Uniform(low=a, high=b) - x = np.array([0.0, 8.0], dtype=np.float32) - expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)]) + x = np.array([0.0, 8.0], dtype=np.float32) + expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)]) - pdf = uniform.prob(x) - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + pdf = uniform.prob(x) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) @test_util.run_in_graph_and_eager_modes def testUniformCDF(self): - with self.test_session(): - batch_size = 6 - a = constant_op.constant([1.0] * batch_size) - b = constant_op.constant([11.0] * batch_size) - a_v = 1.0 - b_v = 11.0 - x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32) + batch_size = 6 + a = constant_op.constant([1.0] * batch_size) + b = constant_op.constant([11.0] * batch_size) + a_v = 1.0 + b_v = 11.0 + x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32) - uniform = uniform_lib.Uniform(low=a, high=b) + uniform = uniform_lib.Uniform(low=a, high=b) - def _expected_cdf(): - cdf = (x - a_v) / (b_v - a_v) - cdf[x >= b_v] = 1 - cdf[x < a_v] = 0 - return cdf + def _expected_cdf(): + cdf = (x - a_v) / (b_v - a_v) + cdf[x >= b_v] = 1 + cdf[x < a_v] = 0 + return cdf - cdf = uniform.cdf(x) - self.assertAllClose(_expected_cdf(), self.evaluate(cdf)) + cdf = uniform.cdf(x) + self.assertAllClose(_expected_cdf(), self.evaluate(cdf)) - log_cdf = uniform.log_cdf(x) - self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf)) + log_cdf = uniform.log_cdf(x) + self.assertAllClose(np.log(_expected_cdf()), self.evaluate(log_cdf)) @test_util.run_in_graph_and_eager_modes def testUniformEntropy(self): - with self.test_session(): - a_v = np.array([1.0, 1.0, 1.0]) - b_v = np.array([[1.5, 2.0, 3.0]]) - uniform = uniform_lib.Uniform(low=a_v, high=b_v) + a_v = np.array([1.0, 1.0, 1.0]) + b_v = np.array([[1.5, 2.0, 3.0]]) + uniform = uniform_lib.Uniform(low=a_v, high=b_v) - expected_entropy = np.log(b_v - a_v) - self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy())) + expected_entropy = np.log(b_v - a_v) + self.assertAllClose(expected_entropy, self.evaluate(uniform.entropy())) @test_util.run_in_graph_and_eager_modes def testUniformAssertMaxGtMin(self): - with self.test_session(): - a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) - b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32) + a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32) + b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32) - with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, - "x < y"): - uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) - self.evaluate(uniform.low) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "x < y"): + uniform = uniform_lib.Uniform(low=a_v, high=b_v, validate_args=True) + self.evaluate(uniform.low) @test_util.run_in_graph_and_eager_modes def testUniformSample(self): - with self.test_session(): - a = constant_op.constant([3.0, 4.0]) - b = constant_op.constant(13.0) - a1_v = 3.0 - a2_v = 4.0 - b_v = 13.0 - n = constant_op.constant(100000) - uniform = uniform_lib.Uniform(low=a, high=b) - - samples = uniform.sample(n, seed=137) - sample_values = self.evaluate(samples) - self.assertEqual(sample_values.shape, (100000, 2)) - self.assertAllClose( - sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.) - self.assertAllClose( - sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.) - self.assertFalse( - np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v)) - self.assertFalse( - np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v)) + a = constant_op.constant([3.0, 4.0]) + b = constant_op.constant(13.0) + a1_v = 3.0 + a2_v = 4.0 + b_v = 13.0 + n = constant_op.constant(100000) + uniform = uniform_lib.Uniform(low=a, high=b) + + samples = uniform.sample(n, seed=137) + sample_values = self.evaluate(samples) + self.assertEqual(sample_values.shape, (100000, 2)) + self.assertAllClose( + sample_values[::, 0].mean(), (b_v + a1_v) / 2, atol=1e-1, rtol=0.) + self.assertAllClose( + sample_values[::, 1].mean(), (b_v + a2_v) / 2, atol=1e-1, rtol=0.) + self.assertFalse( + np.any(sample_values[::, 0] < a1_v) or np.any(sample_values >= b_v)) + self.assertFalse( + np.any(sample_values[::, 1] < a2_v) or np.any(sample_values >= b_v)) @test_util.run_in_graph_and_eager_modes def _testUniformSampleMultiDimensional(self): # DISABLED: Please enable this test once b/issues/30149644 is resolved. - with self.test_session(): - batch_size = 2 - a_v = [3.0, 22.0] - b_v = [13.0, 35.0] - a = constant_op.constant([a_v] * batch_size) - b = constant_op.constant([b_v] * batch_size) - - uniform = uniform_lib.Uniform(low=a, high=b) - - n_v = 100000 - n = constant_op.constant(n_v) - samples = uniform.sample(n) - self.assertEqual(samples.get_shape(), (n_v, batch_size, 2)) - - sample_values = self.evaluate(samples) - - self.assertFalse( - np.any(sample_values[:, 0, 0] < a_v[0]) or - np.any(sample_values[:, 0, 0] >= b_v[0])) - self.assertFalse( - np.any(sample_values[:, 0, 1] < a_v[1]) or - np.any(sample_values[:, 0, 1] >= b_v[1])) - - self.assertAllClose( - sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2) - self.assertAllClose( - sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2) + batch_size = 2 + a_v = [3.0, 22.0] + b_v = [13.0, 35.0] + a = constant_op.constant([a_v] * batch_size) + b = constant_op.constant([b_v] * batch_size) + + uniform = uniform_lib.Uniform(low=a, high=b) + + n_v = 100000 + n = constant_op.constant(n_v) + samples = uniform.sample(n) + self.assertEqual(samples.get_shape(), (n_v, batch_size, 2)) + + sample_values = self.evaluate(samples) + + self.assertFalse( + np.any(sample_values[:, 0, 0] < a_v[0]) or + np.any(sample_values[:, 0, 0] >= b_v[0])) + self.assertFalse( + np.any(sample_values[:, 0, 1] < a_v[1]) or + np.any(sample_values[:, 0, 1] >= b_v[1])) + + self.assertAllClose( + sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2, atol=1e-2) + self.assertAllClose( + sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2, atol=1e-2) @test_util.run_in_graph_and_eager_modes def testUniformMean(self): - with self.test_session(): - a = 10.0 - b = 100.0 - uniform = uniform_lib.Uniform(low=a, high=b) - if not stats: - return - s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean()) + a = 10.0 + b = 100.0 + uniform = uniform_lib.Uniform(low=a, high=b) + if not stats: + return + s_uniform = stats.uniform(loc=a, scale=b - a) + self.assertAllClose(self.evaluate(uniform.mean()), s_uniform.mean()) @test_util.run_in_graph_and_eager_modes def testUniformVariance(self): - with self.test_session(): - a = 10.0 - b = 100.0 - uniform = uniform_lib.Uniform(low=a, high=b) - if not stats: - return - s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var()) + a = 10.0 + b = 100.0 + uniform = uniform_lib.Uniform(low=a, high=b) + if not stats: + return + s_uniform = stats.uniform(loc=a, scale=b - a) + self.assertAllClose(self.evaluate(uniform.variance()), s_uniform.var()) @test_util.run_in_graph_and_eager_modes def testUniformStd(self): - with self.test_session(): - a = 10.0 - b = 100.0 - uniform = uniform_lib.Uniform(low=a, high=b) - if not stats: - return - s_uniform = stats.uniform(loc=a, scale=b - a) - self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std()) + a = 10.0 + b = 100.0 + uniform = uniform_lib.Uniform(low=a, high=b) + if not stats: + return + s_uniform = stats.uniform(loc=a, scale=b - a) + self.assertAllClose(self.evaluate(uniform.stddev()), s_uniform.std()) @test_util.run_in_graph_and_eager_modes def testUniformNans(self): - with self.test_session(): - a = 10.0 - b = [11.0, 100.0] - uniform = uniform_lib.Uniform(low=a, high=b) + a = 10.0 + b = [11.0, 100.0] + uniform = uniform_lib.Uniform(low=a, high=b) - no_nans = constant_op.constant(1.0) - nans = constant_op.constant(0.0) / constant_op.constant(0.0) - self.assertTrue(self.evaluate(math_ops.is_nan(nans))) - with_nans = array_ops.stack([no_nans, nans]) + no_nans = constant_op.constant(1.0) + nans = constant_op.constant(0.0) / constant_op.constant(0.0) + self.assertTrue(self.evaluate(math_ops.is_nan(nans))) + with_nans = array_ops.stack([no_nans, nans]) - pdf = uniform.prob(with_nans) + pdf = uniform.prob(with_nans) - is_nan = self.evaluate(math_ops.is_nan(pdf)) - self.assertFalse(is_nan[0]) - self.assertTrue(is_nan[1]) + is_nan = self.evaluate(math_ops.is_nan(pdf)) + self.assertFalse(is_nan[0]) + self.assertTrue(is_nan[1]) @test_util.run_in_graph_and_eager_modes def testUniformSamplePdf(self): - with self.test_session(): - a = 10.0 - b = [11.0, 100.0] - uniform = uniform_lib.Uniform(a, b) - self.assertTrue( - self.evaluate( - math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0))) + a = 10.0 + b = [11.0, 100.0] + uniform = uniform_lib.Uniform(a, b) + self.assertTrue( + self.evaluate( + math_ops.reduce_all(uniform.prob(uniform.sample(10)) > 0))) @test_util.run_in_graph_and_eager_modes def testUniformBroadcasting(self): - with self.test_session(): - a = 10.0 - b = [11.0, 20.0] - uniform = uniform_lib.Uniform(a, b) + a = 10.0 + b = [11.0, 20.0] + uniform = uniform_lib.Uniform(a, b) - pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]]) - expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]]) - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + pdf = uniform.prob([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]]) + expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]]) + self.assertAllClose(expected_pdf, self.evaluate(pdf)) @test_util.run_in_graph_and_eager_modes def testUniformSampleWithShape(self): - with self.test_session(): - a = 10.0 - b = [11.0, 20.0] - uniform = uniform_lib.Uniform(a, b) - - pdf = uniform.prob(uniform.sample((2, 3))) - # pylint: disable=bad-continuation - expected_pdf = [ - [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], - [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], - ] - # pylint: enable=bad-continuation - self.assertAllClose(expected_pdf, self.evaluate(pdf)) - - pdf = uniform.prob(uniform.sample()) - expected_pdf = [1.0, 0.1] - self.assertAllClose(expected_pdf, self.evaluate(pdf)) + a = 10.0 + b = [11.0, 20.0] + uniform = uniform_lib.Uniform(a, b) + + pdf = uniform.prob(uniform.sample((2, 3))) + # pylint: disable=bad-continuation + expected_pdf = [ + [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], + [[1.0, 0.1], [1.0, 0.1], [1.0, 0.1]], + ] + # pylint: enable=bad-continuation + self.assertAllClose(expected_pdf, self.evaluate(pdf)) + + pdf = uniform.prob(uniform.sample()) + expected_pdf = [1.0, 0.1] + self.assertAllClose(expected_pdf, self.evaluate(pdf)) def testFullyReparameterized(self): a = constant_op.constant(0.1) diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py index 61faa8466e..27d652c2c6 100644 --- a/tensorflow/python/kernel_tests/distributions/util_test.py +++ b/tensorflow/python/kernel_tests/distributions/util_test.py @@ -69,7 +69,7 @@ class AssertCloseTest(test.TestCase): w = array_ops.placeholder(dtypes.float32) feed_dict = {x: [1., 5, 10, 15, 20], y: [1.1, 5, 10, 15, 20], z: [1.0001, 5, 10, 15, 20], w: [1e-8, 5, 10, 15, 20]} - with self.test_session(): + with self.cached_session(): with ops.control_dependencies([du.assert_integer_form(x)]): array_ops.identity(x).eval(feed_dict=feed_dict) @@ -122,58 +122,52 @@ class GetLogitsAndProbsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testImproperArguments(self): - with self.test_session(): - with self.assertRaises(ValueError): - du.get_logits_and_probs(logits=None, probs=None) + with self.assertRaises(ValueError): + du.get_logits_and_probs(logits=None, probs=None) - with self.assertRaises(ValueError): - du.get_logits_and_probs(logits=[0.1], probs=[0.1]) + with self.assertRaises(ValueError): + du.get_logits_and_probs(logits=[0.1], probs=[0.1]) @test_util.run_in_graph_and_eager_modes def testLogits(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) logits = _logit(p) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - logits=logits, validate_args=True) + new_logits, new_p = du.get_logits_and_probs( + logits=logits, validate_args=True) - self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.) - self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.) + self.assertAllClose(p, self.evaluate(new_p), rtol=1e-5, atol=0.) + self.assertAllClose(logits, self.evaluate(new_logits), rtol=1e-5, atol=0.) @test_util.run_in_graph_and_eager_modes def testLogitsMultidimensional(self): p = np.array([0.2, 0.3, 0.5], dtype=np.float32) logits = np.log(p) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - logits=logits, multidimensional=True, validate_args=True) + new_logits, new_p = du.get_logits_and_probs( + logits=logits, multidimensional=True, validate_args=True) - self.assertAllClose(self.evaluate(new_p), p) - self.assertAllClose(self.evaluate(new_logits), logits) + self.assertAllClose(self.evaluate(new_p), p) + self.assertAllClose(self.evaluate(new_logits), logits) @test_util.run_in_graph_and_eager_modes def testProbability(self): p = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - probs=p, validate_args=True) + new_logits, new_p = du.get_logits_and_probs(probs=p, validate_args=True) - self.assertAllClose(_logit(p), self.evaluate(new_logits)) - self.assertAllClose(p, self.evaluate(new_p)) + self.assertAllClose(_logit(p), self.evaluate(new_logits)) + self.assertAllClose(p, self.evaluate(new_p)) @test_util.run_in_graph_and_eager_modes def testProbabilityMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) - with self.test_session(): - new_logits, new_p = du.get_logits_and_probs( - probs=p, multidimensional=True, validate_args=True) + new_logits, new_p = du.get_logits_and_probs( + probs=p, multidimensional=True, validate_args=True) - self.assertAllClose(np.log(p), self.evaluate(new_logits)) - self.assertAllClose(p, self.evaluate(new_p)) + self.assertAllClose(np.log(p), self.evaluate(new_logits)) + self.assertAllClose(p, self.evaluate(new_p)) @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgs(self): @@ -183,29 +177,23 @@ class GetLogitsAndProbsTest(test.TestCase): # Component greater than 1. p3 = [2, 0.2, 0.5, 0.3, .2] - with self.test_session(): - _, prob = du.get_logits_and_probs( - probs=p, validate_args=True) - self.evaluate(prob) - - with self.assertRaisesOpError("Condition x >= 0"): - _, prob = du.get_logits_and_probs( - probs=p2, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p, validate_args=True) + self.evaluate(prob) - _, prob = du.get_logits_and_probs( - probs=p2, validate_args=False) + with self.assertRaisesOpError("Condition x >= 0"): + _, prob = du.get_logits_and_probs(probs=p2, validate_args=True) self.evaluate(prob) - with self.assertRaisesOpError("probs has components greater than 1"): - _, prob = du.get_logits_and_probs( - probs=p3, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p2, validate_args=False) + self.evaluate(prob) - _, prob = du.get_logits_and_probs( - probs=p3, validate_args=False) + with self.assertRaisesOpError("probs has components greater than 1"): + _, prob = du.get_logits_and_probs(probs=p3, validate_args=True) self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p3, validate_args=False) + self.evaluate(prob) + @test_util.run_in_graph_and_eager_modes def testProbabilityValidateArgsMultidimensional(self): p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32) @@ -216,41 +204,39 @@ class GetLogitsAndProbsTest(test.TestCase): # Does not sum to 1. p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32) - with self.test_session(): - _, prob = du.get_logits_and_probs( - probs=p, multidimensional=True) - self.evaluate(prob) - - with self.assertRaisesOpError("Condition x >= 0"): - _, prob = du.get_logits_and_probs( - probs=p2, multidimensional=True, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs(probs=p, multidimensional=True) + self.evaluate(prob) + with self.assertRaisesOpError("Condition x >= 0"): _, prob = du.get_logits_and_probs( - probs=p2, multidimensional=True, validate_args=False) + probs=p2, multidimensional=True, validate_args=True) self.evaluate(prob) - with self.assertRaisesOpError( - "(probs has components greater than 1|probs does not sum to 1)"): - _, prob = du.get_logits_and_probs( - probs=p3, multidimensional=True, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs( + probs=p2, multidimensional=True, validate_args=False) + self.evaluate(prob) + with self.assertRaisesOpError( + "(probs has components greater than 1|probs does not sum to 1)"): _, prob = du.get_logits_and_probs( - probs=p3, multidimensional=True, validate_args=False) + probs=p3, multidimensional=True, validate_args=True) self.evaluate(prob) - with self.assertRaisesOpError("probs does not sum to 1"): - _, prob = du.get_logits_and_probs( - probs=p4, multidimensional=True, validate_args=True) - self.evaluate(prob) + _, prob = du.get_logits_and_probs( + probs=p3, multidimensional=True, validate_args=False) + self.evaluate(prob) + with self.assertRaisesOpError("probs does not sum to 1"): _, prob = du.get_logits_and_probs( - probs=p4, multidimensional=True, validate_args=False) + probs=p4, multidimensional=True, validate_args=True) self.evaluate(prob) + _, prob = du.get_logits_and_probs( + probs=p4, multidimensional=True, validate_args=False) + self.evaluate(prob) + def testProbsMultidimShape(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): p = array_ops.ones([int(2**11+1)], dtype=np.float16) du.get_logits_and_probs( @@ -264,7 +250,7 @@ class GetLogitsAndProbsTest(test.TestCase): prob.eval(feed_dict={p: np.ones([int(2**11+1)])}) def testLogitsMultidimShape(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): l = array_ops.ones([int(2**11+1)], dtype=np.float16) du.get_logits_and_probs( @@ -281,7 +267,7 @@ class GetLogitsAndProbsTest(test.TestCase): class EmbedCheckCategoricalEventShapeTest(test.TestCase): def testTooSmall(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): param = array_ops.ones([1], dtype=np.float16) checked_param = du.embed_check_categorical_event_shape( @@ -295,7 +281,7 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): checked_param.eval(feed_dict={param: np.ones([1])}) def testTooLarge(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(ValueError): param = array_ops.ones([int(2**11+1)], dtype=dtypes.float16) checked_param = du.embed_check_categorical_event_shape( @@ -310,18 +296,17 @@ class EmbedCheckCategoricalEventShapeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testUnsupportedDtype(self): - with self.test_session(): - param = ops.convert_to_tensor( - np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype), - dtype=dtypes.qint16) - with self.assertRaises(TypeError): - du.embed_check_categorical_event_shape(param) + param = ops.convert_to_tensor( + np.ones([2**11 + 1]).astype(dtypes.qint16.as_numpy_dtype), + dtype=dtypes.qint16) + with self.assertRaises(TypeError): + du.embed_check_categorical_event_shape(param) class EmbedCheckIntegerCastingClosedTest(test.TestCase): def testCorrectlyAssertsNonnegative(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements must be non-negative"): x = array_ops.placeholder(dtype=dtypes.float16) x_checked = du.embed_check_integer_casting_closed( @@ -329,7 +314,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, -1], dtype=np.float16)}) def testCorrectlyAssersIntegerForm(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements must be int16-equivalent."): x = array_ops.placeholder(dtype=dtypes.float16) x_checked = du.embed_check_integer_casting_closed( @@ -337,7 +322,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, 1.5], dtype=np.float16)}) def testCorrectlyAssertsLargestPossibleInteger(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements cannot exceed 32767."): x = array_ops.placeholder(dtype=dtypes.int32) x_checked = du.embed_check_integer_casting_closed( @@ -345,7 +330,7 @@ class EmbedCheckIntegerCastingClosedTest(test.TestCase): x_checked.eval(feed_dict={x: np.array([1, 2**15], dtype=np.int32)}) def testCorrectlyAssertsSmallestPossibleInteger(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Elements cannot be smaller than 0."): x = array_ops.placeholder(dtype=dtypes.int32) x_checked = du.embed_check_integer_casting_closed( @@ -365,29 +350,27 @@ class LogCombinationsTest(test.TestCase): log_combs = np.log(special.binom(n, k)) - with self.test_session(): - n = np.array(n, dtype=np.float32) - counts = [[1., 1], [2., 3], [4., 8], [11, 4]] - log_binom = du.log_combinations(n, counts) - self.assertEqual([4], log_binom.get_shape()) - self.assertAllClose(log_combs, self.evaluate(log_binom)) + n = np.array(n, dtype=np.float32) + counts = [[1., 1], [2., 3], [4., 8], [11, 4]] + log_binom = du.log_combinations(n, counts) + self.assertEqual([4], log_binom.get_shape()) + self.assertAllClose(log_combs, self.evaluate(log_binom)) def testLogCombinationsShape(self): # Shape [2, 2] n = [[2, 5], [12, 15]] - with self.test_session(): - n = np.array(n, dtype=np.float32) - # Shape [2, 2, 4] - counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]] - log_binom = du.log_combinations(n, counts) - self.assertEqual([2, 2], log_binom.get_shape()) + n = np.array(n, dtype=np.float32) + # Shape [2, 2, 4] + counts = [[[1., 1, 0, 0], [2., 2, 1, 0]], [[4., 4, 1, 3], [10, 1, 1, 4]]] + log_binom = du.log_combinations(n, counts) + self.assertEqual([2, 2], log_binom.get_shape()) class DynamicShapeTest(test.TestCase): def testSameDynamicShape(self): - with self.test_session(): + with self.cached_session(): scalar = constant_op.constant(2.0) scalar1 = array_ops.placeholder(dtype=dtypes.float32) @@ -497,22 +480,21 @@ class RotateTransposeTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testRollStatic(self): - with self.test_session(): - if context.executing_eagerly(): - error_message = r"Attempt to convert a value \(None\)" - else: - error_message = "None values not supported." - with self.assertRaisesRegexp(ValueError, error_message): - du.rotate_transpose(None, 1) - for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))): - for shift in np.arange(-5, 5): - y = du.rotate_transpose(x, shift) - self.assertAllEqual( - self._np_rotate_transpose(x, shift), self.evaluate(y)) - self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list()) + if context.executing_eagerly(): + error_message = r"Attempt to convert a value \(None\)" + else: + error_message = "None values not supported." + with self.assertRaisesRegexp(ValueError, error_message): + du.rotate_transpose(None, 1) + for x in (np.ones(1), np.ones((2, 1)), np.ones((3, 2, 1))): + for shift in np.arange(-5, 5): + y = du.rotate_transpose(x, shift) + self.assertAllEqual( + self._np_rotate_transpose(x, shift), self.evaluate(y)) + self.assertAllEqual(np.roll(x.shape, shift), y.get_shape().as_list()) def testRollDynamic(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtypes.float32) shift = array_ops.placeholder(dtypes.int32) for x_value in (np.ones( @@ -530,7 +512,7 @@ class RotateTransposeTest(test.TestCase): class PickVectorTest(test.TestCase): def testCorrectlyPicksVector(self): - with self.test_session(): + with self.cached_session(): x = np.arange(10, 12) y = np.arange(15, 18) self.assertAllEqual( @@ -568,19 +550,19 @@ class PreferStaticRankTest(test.TestCase): def testDynamicRankEndsUpBeingNonEmpty(self): x = array_ops.placeholder(np.float64, shape=None) rank = du.prefer_static_rank(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(2, rank.eval(feed_dict={x: np.zeros((2, 3))})) def testDynamicRankEndsUpBeingEmpty(self): x = array_ops.placeholder(np.int32, shape=None) rank = du.prefer_static_rank(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(1, rank.eval(feed_dict={x: []})) def testDynamicRankEndsUpBeingScalar(self): x = array_ops.placeholder(np.int32, shape=None) rank = du.prefer_static_rank(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(0, rank.eval(feed_dict={x: 1})) @@ -607,19 +589,19 @@ class PreferStaticShapeTest(test.TestCase): def testDynamicShapeEndsUpBeingNonEmpty(self): x = array_ops.placeholder(np.float64, shape=None) shape = du.prefer_static_shape(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual((2, 3), shape.eval(feed_dict={x: np.zeros((2, 3))})) def testDynamicShapeEndsUpBeingEmpty(self): x = array_ops.placeholder(np.int32, shape=None) shape = du.prefer_static_shape(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array([0]), shape.eval(feed_dict={x: []})) def testDynamicShapeEndsUpBeingScalar(self): x = array_ops.placeholder(np.int32, shape=None) shape = du.prefer_static_shape(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array([]), shape.eval(feed_dict={x: 1})) @@ -646,20 +628,20 @@ class PreferStaticValueTest(test.TestCase): def testDynamicValueEndsUpBeingNonEmpty(self): x = array_ops.placeholder(np.float64, shape=None) value = du.prefer_static_value(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.zeros((2, 3)), value.eval(feed_dict={x: np.zeros((2, 3))})) def testDynamicValueEndsUpBeingEmpty(self): x = array_ops.placeholder(np.int32, shape=None) value = du.prefer_static_value(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array([]), value.eval(feed_dict={x: []})) def testDynamicValueEndsUpBeingScalar(self): x = array_ops.placeholder(np.int32, shape=None) value = du.prefer_static_value(x) - with self.test_session(): + with self.cached_session(): self.assertAllEqual(np.array(1), value.eval(feed_dict={x: 1})) @@ -691,7 +673,7 @@ class FillTriangularTest(test.TestCase): def _run_test(self, x_, use_deferred_shape=False, **kwargs): x_ = np.asarray(x_) - with self.test_session() as sess: + with self.cached_session() as sess: static_shape = None if use_deferred_shape else x_.shape x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) # Add `zeros_like(x)` such that x's value and gradient are identical. We @@ -761,7 +743,7 @@ class FillTriangularInverseTest(FillTriangularTest): def _run_test(self, x_, use_deferred_shape=False, **kwargs): x_ = np.asarray(x_) - with self.test_session() as sess: + with self.cached_session() as sess: static_shape = None if use_deferred_shape else x_.shape x_pl = array_ops.placeholder_with_default(x_, shape=static_shape) zeros_like_x_pl = (x_pl * array_ops.stop_gradient(x_pl - 1.) @@ -795,7 +777,7 @@ class ReduceWeightedLogSumExp(test.TestCase): logx_ = np.array([[0., -1, 1000.], [0, 1, -1000.], [-5, 0, 5]]) - with self.test_session() as sess: + with self.cached_session() as sess: logx = constant_op.constant(logx_) expected = math_ops.reduce_logsumexp(logx, axis=-1) grad_expected = gradients_impl.gradients(expected, logx)[0] @@ -818,7 +800,7 @@ class ReduceWeightedLogSumExp(test.TestCase): [1, -2, 1], [1, 0, 1]]) expected, _ = self._reduce_weighted_logsumexp(logx_, w_, axis=-1) - with self.test_session() as sess: + with self.cached_session() as sess: logx = constant_op.constant(logx_) w = constant_op.constant(w_) actual, actual_sgn = du.reduce_weighted_logsumexp( @@ -836,7 +818,7 @@ class ReduceWeightedLogSumExp(test.TestCase): [1, 0, 1]]) expected, _ = self._reduce_weighted_logsumexp( logx_, w_, axis=-1, keep_dims=True) - with self.test_session() as sess: + with self.cached_session() as sess: logx = constant_op.constant(logx_) w = constant_op.constant(w_) actual, actual_sgn = du.reduce_weighted_logsumexp( @@ -848,7 +830,7 @@ class ReduceWeightedLogSumExp(test.TestCase): def testDocString(self): """This test verifies the correctness of the docstring examples.""" - with self.test_session(): + with self.cached_session(): x = constant_op.constant([[0., 0, 0], [0, 0, 0]]) @@ -952,7 +934,7 @@ class SoftplusTest(test.TestCase): use_gpu=True) def testGradient(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant( [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9], shape=[2, 5], @@ -968,7 +950,7 @@ class SoftplusTest(test.TestCase): self.assertLess(err, 1e-4) def testInverseSoftplusGradientNeverNan(self): - with self.test_session(): + with self.cached_session(): # Note that this range contains both zero and inf. x = constant_op.constant(np.logspace(-8, 6).astype(np.float16)) y = du.softplus_inverse(x) @@ -977,7 +959,7 @@ class SoftplusTest(test.TestCase): self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads)) def testInverseSoftplusGradientFinite(self): - with self.test_session(): + with self.cached_session(): # This range of x is all finite, and so is 1 / x. So the # gradient and its approximations should be finite as well. x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16)) -- GitLab From ffcbd466a04f6c65623882dd4657d2e558521bb9 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Tue, 4 Sep 2018 16:06:26 -0700 Subject: [PATCH 084/540] Internal change. PiperOrigin-RevId: 211542593 --- tensorflow/contrib/lite/RELEASE.md | 8 -------- tensorflow/contrib/lite/g3doc/README.md | 4 ---- tensorflow/contrib/lite/g3doc/api_docs/python/index.md | 10 ---------- 3 files changed, 22 deletions(-) delete mode 100644 tensorflow/contrib/lite/RELEASE.md delete mode 100644 tensorflow/contrib/lite/g3doc/README.md delete mode 100644 tensorflow/contrib/lite/g3doc/api_docs/python/index.md diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md deleted file mode 100644 index 8fd63d5cee..0000000000 --- a/tensorflow/contrib/lite/RELEASE.md +++ /dev/null @@ -1,8 +0,0 @@ -# Release 0.1.7 - -* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit - fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0). -* To reproduce the iOS library, it's required to cherry pick git commit - f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue. -* The code is based on TensorFlow 1.8.0 release candidate and it's very close - to TensorFlow 1.8.0 release. diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md deleted file mode 100644 index e3db478481..0000000000 --- a/tensorflow/contrib/lite/g3doc/README.md +++ /dev/null @@ -1,4 +0,0 @@ -This is a *work-in-progress* TF Lite subsite for: -https://www.tensorflow.org/mobile - -DO NOT PUBLISH diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md deleted file mode 100644 index 70031a3c3d..0000000000 --- a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md +++ /dev/null @@ -1,10 +0,0 @@ -Project: /mobile/_project.yaml -Book: /mobile/_book.yaml -page_type: reference - - - - -# All symbols in TensorFlow Lite - -TEMP PAGE -- GitLab From 7a2f0e251951fff033c57970a76d7339a79fc185 Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Tue, 4 Sep 2018 16:09:43 -0700 Subject: [PATCH 085/540] Replace floating point functionality with integer alternative for microcontrollers PiperOrigin-RevId: 211543125 --- .../kernels/internal/quantization_util.cc | 210 ++++++++++++++++++ .../lite/kernels/internal/quantization_util.h | 38 ++++ .../internal/quantization_util_test.cc | 133 +++++++++++ 3 files changed, 381 insertions(+) diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index f882f9910e..544ef16ce1 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -23,6 +23,32 @@ limitations under the License. namespace tflite { +namespace { +// These constants are used to manipulate the binary representation of doubles. +// Double-precision binary64 floating point format is: +// Bit | 63 | 62-52 | 51-0 | +// | Sign | Exponent | Fraction | +// To avoid 64-bit integers as much as possible, I break this into high and +// low 32-bit chunks. High is: +// Bit | 31 | 30-20 | 19-0 | +// | Sign | Exponent | High Fraction | +// Low is: +// Bit | 31-0 | +// | Low Fraction | +// We then access the components through logical bit-wise operations to +// extract the parts needed, with the positions and masks derived from the +// layout shown above. +constexpr uint64_t kSignMask = 0x8000000000000000LL; +constexpr uint64_t kExponentMask = 0x7ff0000000000000LL; +constexpr int32_t kExponentShift = 52; +constexpr int32_t kExponentBias = 1023; +constexpr uint32_t kExponentIsBadNum = 0x7ff; +constexpr uint64_t kFractionMask = 0x000fffffffc00000LL; +constexpr uint32_t kFractionShift = 22; +constexpr uint32_t kFractionRoundingMask = 0x003fffff; +constexpr uint32_t kFractionRoundingThreshold = 0x00200000; +} // namespace + void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift) { if (double_multiplier == 0.) { @@ -30,8 +56,16 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, *shift = 0; return; } +#ifdef TFLITE_EMULATE_FLOAT + // If we're trying to avoid the use of floating-point instructions (for + // example on microcontrollers) then use an alternative implementation + // that only requires integer and bitwise operations. To enable this, you + // need to set the define during the build process for your platform. + int64_t q_fixed = IntegerFrExp(double_multiplier, shift); +#else // TFLITE_EMULATE_FLOAT const double q = std::frexp(double_multiplier, shift); auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); +#endif // TFLITE_EMULATE_FLOAT TFLITE_CHECK(q_fixed <= (1ll << 31)); if (q_fixed == (1ll << 31)) { q_fixed /= 2; @@ -60,6 +94,163 @@ void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, *left_shift = shift; } +int64_t IntegerFrExp(double input, int* shift) { + // Make sure our assumptions about the double layout hold. + TFLITE_CHECK_EQ(8, sizeof(double)); + + // We want to access the bits of the input double value directly, which is + // tricky to do safely, so use a union to handle the casting. + union { + double double_value; + uint64_t double_as_uint; + } cast_union; + cast_union.double_value = input; + const uint64_t u = cast_union.double_as_uint; + + // If the bitfield is all zeros apart from the sign bit, this is a normalized + // zero value, so return standard values for this special case. + if ((u & ~kSignMask) == 0) { + *shift = 0; + return 0; + } + + // Deal with NaNs and Infs, which are always indicated with a fixed pattern in + // the exponent, and distinguished by whether the fractions are zero or + // non-zero. + const uint32_t exponent_part = ((u & kExponentMask) >> kExponentShift); + if (exponent_part == kExponentIsBadNum) { + *shift = std::numeric_limits::max(); + if (u & kFractionMask) { + // NaN, so just return zero (with the exponent set to INT_MAX). + return 0; + } else { + // Infinity, so return +/- INT_MAX. + if (u & kSignMask) { + return std::numeric_limits::min(); + } else { + return std::numeric_limits::max(); + } + } + } + + // The shift is fairly easy to extract from the high bits of the double value, + // just by masking it out and applying a bias. The std::frexp() implementation + // always returns values between 0.5 and 1.0 though, whereas the exponent + // assumes 1.0 to 2.0 is the standard range, so I add on one to match that + // interface. + *shift = (exponent_part - kExponentBias) + 1; + + // There's an implicit high bit in the double format definition, so make sure + // we include that at the top, and then reconstruct the rest of the fractional + // value from the remaining fragments. + int64_t fraction = 0x40000000 + ((u & kFractionMask) >> kFractionShift); + + // We're cutting off some bits at the bottom, so to exactly match the standard + // frexp implementation here we'll apply rounding by adding one to the least + // significant bit of the result if the discarded portion is over half of the + // maximum. + if ((u & kFractionRoundingMask) > kFractionRoundingThreshold) { + fraction += 1; + } + // Negate the fraction if the sign bit was set. + if (u & kSignMask) { + fraction *= -1; + } + + return fraction; +} + +double DoubleFromFractionAndShift(int64_t fraction, int shift) { + union { + double double_value; + uint64_t double_as_uint; + } result; + + // Detect NaNs and infinities. + if (shift == std::numeric_limits::max()) { + if (fraction == 0) { + return NAN; + } else if (fraction > 0) { + return INFINITY; + } else { + return -INFINITY; + } + } + + // Return a normalized zero for a zero fraction. + if (fraction == 0) { + result.double_as_uint = 0; + return result.double_value; + } + + bool is_negative = (fraction < 0); + int64_t encoded_fraction = is_negative ? -fraction : fraction; + int64_t encoded_shift = (shift - 1); + while (encoded_fraction < 0x40000000) { + encoded_fraction *= 2; + encoded_shift -= 1; + } + while (encoded_fraction > 0x80000000) { + encoded_fraction /= 2; + encoded_shift += 1; + } + encoded_fraction -= 0x40000000; + if (encoded_shift < -1022) { + encoded_shift = -1023; + } else if (encoded_shift > 1022) { + encoded_shift = 1023; + } + encoded_shift += kExponentBias; + uint64_t encoded_sign = is_negative ? kSignMask : 0; + result.double_as_uint = encoded_sign | (encoded_shift << kExponentShift) | + (encoded_fraction << kFractionShift); + return result.double_value; +} + +double IntegerDoubleMultiply(double a, double b) { + int a_shift; + const int64_t a_fraction = IntegerFrExp(a, &a_shift); + int b_shift; + const int64_t b_fraction = IntegerFrExp(b, &b_shift); + // Detect NaNs and infinities. + if (a_shift == std::numeric_limits::max() || + (b_shift == std::numeric_limits::max())) { + return NAN; + } + const int result_shift = a_shift + b_shift + 1; + const int64_t result_fraction = (a_fraction * b_fraction) >> 32; + return DoubleFromFractionAndShift(result_fraction, result_shift); +} + +int IntegerDoubleCompare(double a, double b) { + int a_shift; + const int64_t a_fraction = IntegerFrExp(a, &a_shift); + int b_shift; + const int64_t b_fraction = IntegerFrExp(b, &b_shift); + + // Detect NaNs and infinities. + if (a_shift == std::numeric_limits::max() || + (b_shift == std::numeric_limits::max())) { + return 1; + } + + if ((a_fraction == 0) && (b_fraction < 0)) { + return 1; + } else if ((a_fraction < 0) && (b_fraction == 0)) { + return -1; + } else if (a_shift < b_shift) { + return -1; + } else if (a_shift > b_shift) { + return 1; + } else if (a_fraction < b_fraction) { + return -1; + } else if (a_fraction > b_fraction) { + return 1; + } else { + return 0; + } +} + void PreprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits, int32_t* quantized_multiplier, int* left_shift) { @@ -72,8 +263,20 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, // result is double equivalent of Q0.31 (actually with more precision). Thus // this generates a Q(input_integer_bits).(31-input_integer_bits) // representation. +#ifdef TFLITE_EMULATE_FLOAT + const double input_beta = IntegerDoubleMultiply(beta, input_scale); + int shift; + int64_t fraction = IntegerFrExp(input_beta, &shift); + shift += (31 - input_integer_bits); + double input_beta_real_multiplier = + DoubleFromFractionAndShift(fraction, shift); + if (IntegerDoubleCompare(input_beta_real_multiplier, (1ll << 31) - 1.0) > 0) { + input_beta_real_multiplier = (1ll << 31) - 1.0; + } +#else // TFLITE_EMULATE_FLOAT const double input_beta_real_multiplier = std::min( beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0); +#endif // TFLITE_EMULATE_FLOAT QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, quantized_multiplier, left_shift); @@ -97,6 +300,12 @@ void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, } int CalculateInputRadius(int input_integer_bits, int input_left_shift) { +#ifdef TFLITE_EMULATE_FLOAT + int64_t result = (1 << input_integer_bits) - 1; + result <<= (31 - input_integer_bits); + result >>= input_left_shift; + return result; +#else // TFLITE_EMULATE_FLOAT const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) * (1ll << (31 - input_integer_bits)) / (1ll << input_left_shift); @@ -104,6 +313,7 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) { // After scaling the difference, the result would be at the maximum. Thus we // must ensure that our value has lower magnitude. return static_cast(std::floor(max_input_rescaled)); +#endif // TFLITE_EMULATE_FLOAT } void NudgeQuantizationRange(const float min, const float max, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 9ee4a47fbb..d74a1bac97 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -195,6 +195,44 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, int* shift); +// Splits a double input value into a returned fraction, and a shift value from +// the exponent, using only bitwise and integer operations to support +// microcontrollers and other environments without floating-point support. +// +// This is designed to be a replacement for how std::frexp() is used within the +// QuantizeMultiplier() function, and so has a different signature than the +// standard version, returning a 64-bit integer rather than a double. This +// result has a maximum value of 1<<31, with the fraction expressed as a +// proportion of that maximum. +// +// std::frexp() returns NaNs and infinities unmodified, but since we're +// returning integers that can't represent those values, instead we return +// a shift of std::numeric_limits::max() for all bad numbers, with an int64 +// result of 0 for NaNs, std:numeric_limits::max() for +INFINITY, and +// std::numeric_limits::min() for -INFINITY. Denormalized inputs will +// result in return values that end up truncating some bits at the end, +// reflecting the loss of precision inherent in denormalization. +int64_t IntegerFrExp(double input, int* shift); + +// Converts an integer fraction in the format produced by IntegerFrExp (where +// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an +// IEEE binary64 double format result. The implementation uses only integer and +// bitwise operators, so no floating point hardware support or emulation is +// needed. This is here so quantized operations can run non-time-critical +// preparation calculations on microcontrollers and other platforms without +// float support. +double DoubleFromFractionAndShift(int64_t fraction, int shift); + +// Performs a multiplication of two numbers in double format, using only integer +// and bitwise instructions. This is aimed at supporting housekeeping functions +// for quantized operations on microcontrollers without floating-point hardware. +double IntegerDoubleMultiply(double a, double b); + +// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is +// greater than b. It is implemented using only integer and logical instructions +// so that it can be easily run on microcontrollers for quantized operations. +int IntegerDoubleCompare(double a, double b); + // This first creates a multiplier in a double equivalent of // Q(input_integer_bits).(31-input_integer_bits) representation, with extra // precision in the double's fractional bits. It then splits the result into diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 00fc3e91dc..14281f25c6 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -191,6 +191,139 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) { EXPECT_EQ(qp.zero_point, 255); } +TEST(QuantizationUtilTest, IntegerFrExp) { + int shift; + int64_t result = IntegerFrExp(0.0, &shift); + EXPECT_EQ(0, result); + EXPECT_EQ(0, shift); + + result = IntegerFrExp(1.0, &shift); + EXPECT_NEAR(0x40000000, result, 1); + EXPECT_EQ(1, shift); + + result = IntegerFrExp(0.25, &shift); + EXPECT_NEAR(0x40000000, result, 1); + EXPECT_EQ(-1, shift); + + result = IntegerFrExp(-1.0, &shift); + EXPECT_NEAR(-(1 << 30), result, 1); + EXPECT_EQ(1, shift); + + result = IntegerFrExp(123.45, &shift); + EXPECT_NEAR(2071147315, result, 1); + EXPECT_EQ(7, shift); + + result = IntegerFrExp(NAN, &shift); + EXPECT_NEAR(0, result, 1); + EXPECT_EQ(0x7fffffff, shift); + + result = IntegerFrExp(INFINITY, &shift); + EXPECT_NEAR(std::numeric_limits::max(), result, 1); + EXPECT_EQ(0x7fffffff, shift); + + result = IntegerFrExp(-INFINITY, &shift); + EXPECT_NEAR(std::numeric_limits::min(), result, 1); + EXPECT_EQ(0x7fffffff, shift); +} + +TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) { + int shift; + int32_t result = IntegerFrExp(0.0, &shift); + EXPECT_EQ(result, 0); + EXPECT_EQ(shift, 0); + + int double_shift; + double double_result = std::frexp(0.0, &double_shift); + EXPECT_EQ(double_result, 0); + EXPECT_EQ(double_shift, 0); + + result = IntegerFrExp(1.0, &shift); + EXPECT_NEAR(result, 0x40000000, 1); + EXPECT_EQ(shift, 1); + double_result = std::frexp(1.0, &double_shift); + EXPECT_NEAR(double_result, 0.5, 1e-5); + EXPECT_EQ(double_shift, 1); + + result = IntegerFrExp(0.25, &shift); + EXPECT_NEAR(result, 0x40000000, 1); + EXPECT_EQ(shift, -1); + double_result = std::frexp(0.25, &double_shift); + EXPECT_NEAR(double_result, 0.5, 1e-5); + EXPECT_EQ(double_shift, -1); + + result = IntegerFrExp(-1.0, &shift); + EXPECT_NEAR(result, -(1 << 30), 1); + EXPECT_EQ(shift, 1); + double_result = std::frexp(-1.0, &double_shift); + EXPECT_NEAR(double_result, -0.5, 1e-5); + EXPECT_EQ(double_shift, 1); + + result = IntegerFrExp(123.45, &shift); + EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000); + EXPECT_EQ(shift, 7); + double_result = std::frexp(123.45, &double_shift); + EXPECT_NEAR(double_result, 0.964453, 1e-5); + EXPECT_EQ(double_shift, 7); +} + +TEST(QuantizationUtilTest, DoubleFromFractionAndShift) { + double result = DoubleFromFractionAndShift(0, 0); + EXPECT_EQ(0, result); + + result = DoubleFromFractionAndShift(0x40000000, 1); + EXPECT_NEAR(1.0, result, 1e-5); + + result = DoubleFromFractionAndShift(0x40000000, 2); + EXPECT_NEAR(2.0, result, 1e-5); + + int shift; + int64_t fraction = IntegerFrExp(3.0, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_NEAR(3.0, result, 1e-5); + + fraction = IntegerFrExp(123.45, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_NEAR(123.45, result, 1e-5); + + fraction = IntegerFrExp(-23.232323, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_NEAR(-23.232323, result, 1e-5); + + fraction = IntegerFrExp(NAN, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_TRUE(std::isnan(result)); + + fraction = IntegerFrExp(INFINITY, &shift); + result = DoubleFromFractionAndShift(fraction, shift); + EXPECT_FALSE(std::isfinite(result)); +} + +TEST(QuantizationUtilTest, IntegerDoubleMultiply) { + EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), 1e-5); + EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), 1e-5); + EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), 1e-5); + EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), 1e-5); + EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), 1e-5); + EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), 1e-5); + EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), 1e-5); + EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), 1e-5); + EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), 1e-5); + EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), 1e-5); + EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0))); + EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN))); +} + +TEST(QuantizationUtilTest, IntegerDoubleCompare) { + EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0)); + EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0)); + EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0)); + EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0)); + EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0)); + EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0)); + EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY)); + EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN)); +} + #ifdef GTEST_HAS_DEATH_TEST TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); -- GitLab From 84ada6e2ce3d830f5cf3490e30f408f7459d0eab Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 16:11:35 -0700 Subject: [PATCH 086/540] Fix flakiness in ConvolutionInPlaneTest.testVertConvWithBlankImage by switching from assertAllEqual to assertAllClose. PiperOrigin-RevId: 211543406 --- tensorflow/contrib/layers/python/layers/layers_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index eee90864b4..52c9c4f3be 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1288,7 +1288,7 @@ class ConvolutionInPlaneTest(test.TestCase): result = sess.run(vert_gradients) expected = np.zeros((1, 9, 10, 1)) - self.assertAllEqual(result, expected) + self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) def testVertConvWithVaryingImage(self): image = np.asmatrix(('1.0 2.0 3.0;' '1.1 2.0 4.0;' '-4.3 0.0 8.9')) -- GitLab From 462f1871ee405ba7184a6d4c113d15b764e80324 Mon Sep 17 00:00:00 2001 From: Sergii Khomenko Date: Wed, 5 Sep 2018 01:06:21 +0200 Subject: [PATCH 087/540] Add an explicit reason for NotImplementedError on eager model save --- tensorflow/python/keras/engine/network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index cd74e36e68..f8c23ed124 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -1355,7 +1355,9 @@ class Network(base_layer.Layer): ``` """ if not self._is_graph_network: - raise NotImplementedError + raise NotImplementedError( + 'Currently `save` requires model to be a graph network. Consider ' + 'using `save_weights`, in order to save the weights of the model.') from tensorflow.python.keras.models import save_model # pylint: disable=g-import-not-at-top save_model(self, filepath, overwrite, include_optimizer) -- GitLab From a2e3dcdb4f439f05592b3e4698cb25a28d85a3b7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 16:52:11 -0700 Subject: [PATCH 088/540] There were two different error reporting formats within TensorFlow: `{{key value}}` and `^^key:value^^`. This change consolidate these two format. PiperOrigin-RevId: 211550259 --- tensorflow/core/common_runtime/placer.cc | 52 +++++++------------ tensorflow/core/common_runtime/placer.h | 2 - tensorflow/core/common_runtime/placer_test.cc | 50 +++++------------- tensorflow/core/lib/core/errors.h | 4 +- tensorflow/core/protobuf/config.proto | 9 ++-- tensorflow/python/client/session.py | 4 +- .../python/framework/error_interpolation.py | 14 ++--- .../framework/error_interpolation_test.py | 22 ++++---- ...nsorflow.-config-proto.-experimental.pbtxt | 10 ++-- .../golden/v1/tensorflow.-config-proto.pbtxt | 10 ++-- ...nsorflow.-config-proto.-experimental.pbtxt | 10 ++-- .../golden/v2/tensorflow.-config-proto.pbtxt | 10 ++-- 12 files changed, 76 insertions(+), 121 deletions(-) diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index 7f3c25d81d..3b59995433 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -254,9 +254,11 @@ class ColocationGraph { old_root_member.device_name, allow_soft_placement_); if (!s.ok()) { - return errors::InvalidArgument("Cannot colocate nodes '", x.name(), - "' and '", y.name(), ": ", - s.error_message()); + return errors::InvalidArgument( + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), ": ", + s.error_message()); } // Ensure that the common root has at least one supported device @@ -267,8 +269,10 @@ class ColocationGraph { old_root_member.supported_device_types); if (new_root_member.supported_device_types.empty()) { return errors::InvalidArgument( - "Cannot colocate nodes '", x.name(), "' and '", y.name(), - "' because no device type supports both of those nodes and the " + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), + " because no device type supports both of those nodes and the " "other nodes colocated with them.", DebugInfo(x_root), DebugInfo(y_root)); } @@ -376,8 +380,9 @@ class ColocationGraph { // merged set device is different, so print both. return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->requested_device(), - "' because the node was colocated with a group of nodes that " + node->requested_device(), "' because the node ", + errors::FormatColocationNodeForError(node->name()), + " was colocated with a group of nodes that ", "required incompatible device '", DeviceNameUtils::ParsedNameToString( members_[node_root].device_name), @@ -809,10 +814,10 @@ Status Placer::Run() { std::vector* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } // Returns the first device in sorted devices list so we will always @@ -856,10 +861,10 @@ Status Placer::Run() { std::vector* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } int assigned_device = -1; @@ -925,21 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const { } } -bool Placer::ClientHandlesErrorFormatting() const { - return options_ != nullptr && - options_->config.experimental().client_handles_error_formatting(); -} - -// Returns the node name in single quotes. If the client handles formatted -// errors, appends a formatting tag which the client will reformat into, for -// example, " (defined at filename:123)". -// TODO(shikharagarwal): Remove this function once -// client_handles_error_formatting flag is removed. -string Placer::RichNodeName(const Node* node) const { - if (ClientHandlesErrorFormatting()) { - return errors::FormatNodeNameForError(node->name()); - } - return strings::StrCat("'", node->name(), "'"); -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h index cefcdd25db..f97ffe7372 100644 --- a/tensorflow/core/common_runtime/placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -87,8 +87,6 @@ class Placer { // placement if the SessionOptions entry in 'options_' requests it. void AssignAndLog(int assigned_device, Node* node) const; void LogDeviceAssignment(const Node* node) const; - bool ClientHandlesErrorFormatting() const; - string RichNodeName(const Node* node) const; Graph* const graph_; // Not owned. const DeviceSet* const devices_; // Not owned. diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 83d27e2730..9b8a95e3b6 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) { } Status s = Place(&g); - EXPECT_TRUE( - str_util::StrContains(s.error_message(), - "Cannot colocate nodes 'foo' and 'in' because no " - "device type supports both of those nodes and the " - "other nodes colocated with them")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Cannot colocate nodes {{colocation_node foo}} and " + "{{colocation_node in}} because no device type supports both of those " + "nodes and the other nodes colocated with them")); } TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) { @@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) { Status s = Place(&g); EXPECT_TRUE(str_util::StrContains( s.error_message(), - "Cannot colocate nodes 'var3' and 'assign3' because no " - "device type supports both of those nodes and the other " - "nodes colocated with them.")); + "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node " + "assign3}} because no device type supports both of those nodes and the " + "other nodes colocated with them.")); } TEST_F(PlacerTest, TestColocationAndReferenceConnections) { @@ -1154,35 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) { } SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - true); Status s = Place(&g, &options); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); LOG(WARNING) << s.error_message(); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot assign a device for operation {{node in}}")); -} - -// Test that the "Cannot assign a device" error message does not contain a -// format tag when not it shouldn't -TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestDevice", - b.opts().WithName("in").WithDevice("/device:fakegpu:11")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - false); - Status s = Place(&g, &options); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot assign a device for operation 'in'")); - EXPECT_FALSE(str_util::StrContains( - s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Cannot assign a device for operation in")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}")); } // Test that placement fails when a node requests an explicit device that is not @@ -1288,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot colocate nodes 'var' and 'assign'")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Cannot colocate nodes {{colocation_node " + "var}} and {{colocation_node assign}}")); } // Test that a generator node follows its consumers (where there are several diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h index 982901a39c..d5cbe6c616 100644 --- a/tensorflow/core/lib/core/errors.h +++ b/tensorflow/core/lib/core/errors.h @@ -136,11 +136,9 @@ string FormatNodeNamesForError(const T& names) { ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s)); }); } -// TODO(b/113350742): Consolidate the two different formats `{{key value}}` and -// `^^key:value^^` in a follow-on CL. // LINT.IfChange inline string FormatColocationNodeForError(const string& name) { - return strings::StrCat("^^colocation_node:", name, "^^"); + return strings::StrCat("{{colocation_node ", name, "}}"); } // LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) template diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index da3a99565e..625d5649e6 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -390,9 +390,12 @@ message ConfigProto { message Experimental { // Task name for group resolution. string collective_group_leader = 1; - // Whether the client will format templated errors. For example, the string: - // "The node was defined on ^^node:Foo:${file}:${line}^^". - bool client_handles_error_formatting = 2; + + // We removed the flag client_handles_error_formatting. Marking the tag + // number as reserved. + // TODO(shikharagarwal): Should we just remove this tag so that it can be + // used in future for other purpose? + reserved 2; // Which executor to use, the default executor will be used // if it is an empty string or "DEFAULT" diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 1841dd998b..e4273fe8a0 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1302,9 +1302,7 @@ class BaseSession(SessionInterface): node_def = op.node_def except KeyError: pass - if (self._config is not None and - self._config.experimental.client_handles_error_formatting): - message = error_interpolation.interpolate(message, self._graph) + message = error_interpolation.interpolate(message, self._graph) raise type(e)(node_def, op, message) def _extend_graph(self): diff --git a/tensorflow/python/framework/error_interpolation.py b/tensorflow/python/framework/error_interpolation.py index a69018d00d..46bda2e621 100644 --- a/tensorflow/python/framework/error_interpolation.py +++ b/tensorflow/python/framework/error_interpolation.py @@ -15,7 +15,7 @@ """Function for interpolating formatted errors from the TensorFlow runtime. Exposes the function `interpolate` to interpolate messages with tags of the form -^^type:name:format^^. +{{type name}}. """ from __future__ import absolute_import @@ -32,7 +32,7 @@ import six from tensorflow.python.util import tf_stack _NAME_REGEX = r"[A-Za-z0-9.][A-Za-z0-9_.\-/]*?" -_TAG_REGEX = r"\^\^({name}):({name})\^\^".format(name=_NAME_REGEX) +_TAG_REGEX = r"{{{{({name}) ({name})}}}}".format(name=_NAME_REGEX) _INTERPOLATION_REGEX = r"^(.*?)({tag})".format(tag=_TAG_REGEX) _INTERPOLATION_PATTERN = re.compile(_INTERPOLATION_REGEX) @@ -48,8 +48,8 @@ def _parse_message(message): """Parses the message. Splits the message into separators and tags. Tags are named tuples - representing the string ^^type:name^^ and they are separated by - separators. For example, in "123^^node:Foo^^456^^node:Bar^^789", there are + representing the string {{type name}} and they are separated by + separators. For example, in "123{{node Foo}}456{{node Bar}}789", there are two tags and three separators. The separators are the numeric characters. Args: @@ -58,7 +58,7 @@ def _parse_message(message): Returns: (list of separator strings, list of _ParseTags). - For example, if message is "123^^node:Foo^^456" then this function + For example, if message is "123{{node Foo}}456" then this function returns (["123", "456"], [_ParseTag("node", "Foo")]) """ seps = [] @@ -276,7 +276,7 @@ def interpolate(error_message, graph): message. Returns: - The string with tags of the form ^^type:name^^ interpolated. + The string with tags of the form {{type name}} interpolated. """ seps, tags = _parse_message(error_message) subs = [] @@ -288,7 +288,7 @@ def interpolate(error_message, graph): except KeyError: op = None - msg = "^^%s:%s^^" % (t.type, t.name) + msg = "{{%s %s}}" % (t.type, t.name) if op is not None: field_dict = compute_field_dict(op) if t.type == "node": diff --git a/tensorflow/python/framework/error_interpolation_test.py b/tensorflow/python/framework/error_interpolation_test.py index a7c7bbf28b..d312b825d2 100644 --- a/tensorflow/python/framework/error_interpolation_test.py +++ b/tensorflow/python/framework/error_interpolation_test.py @@ -167,20 +167,20 @@ class InterpolateFilenamesAndLineNumbersTest(test.TestCase): self.assertEqual(interpolated_string, normal_string) def testOneTagWithAFakeNameResultsInPlaceholders(self): - one_tag_string = "^^node:MinusOne^^" + one_tag_string = "{{node MinusOne}}" interpolated_string = error_interpolation.interpolate( one_tag_string, self.graph) self.assertEqual(one_tag_string, interpolated_string) def testTwoTagsNoSeps(self): - two_tags_no_seps = "^^node:One^^^^node:Three^^" + two_tags_no_seps = "{{node One}}{{node Three}}" interpolated_string = error_interpolation.interpolate( two_tags_no_seps, self.graph) self.assertRegexpMatches(interpolated_string, "constant_op.py:[0-9]+.*constant_op.py:[0-9]+") def testTwoTagsWithSeps(self): - two_tags_with_seps = ";;;^^node:Two^^,,,^^node:Three^^;;;" + two_tags_with_seps = ";;;{{node Two}},,,{{node Three}};;;" interpolated_string = error_interpolation.interpolate( two_tags_with_seps, self.graph) expected_regex = ( @@ -206,23 +206,23 @@ class InterpolateDeviceSummaryTest(test.TestCase): self.graph = self.three.graph def testNodeZeroHasNoDeviceSummaryInfo(self): - message = "^^colocation_node:zero^^" + message = "{{colocation_node zero}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("No device assignments were active", result) def testNodeOneHasExactlyOneInterpolatedDevice(self): - message = "^^colocation_node:one^^" + message = "{{colocation_node one}}" result = error_interpolation.interpolate(message, self.graph) self.assertEqual(2, result.count("tf.device(/cpu)")) def testNodeTwoHasTwoInterpolatedDevice(self): - message = "^^colocation_node:two^^" + message = "{{colocation_node two}}" result = error_interpolation.interpolate(message, self.graph) self.assertEqual(2, result.count("tf.device(/cpu)")) self.assertEqual(2, result.count("tf.device(/cpu:0)")) def testNodeThreeHasFancyFunctionDisplayNameForInterpolatedDevice(self): - message = "^^colocation_node:three^^" + message = "{{colocation_node three}}" result = error_interpolation.interpolate(message, self.graph) num_devices = result.count("tf.device") self.assertEqual(2, num_devices) @@ -256,12 +256,12 @@ class InterpolateColocationSummaryTest(test.TestCase): self.graph = node_three.graph def testNodeThreeHasColocationInterpolation(self): - message = "^^colocation_node:Three_with_one^^" + message = "{{colocation_node Three_with_one}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(One)", result) def testNodeFourHasColocationInterpolationForNodeThreeOnly(self): - message = "^^colocation_node:Four_with_three^^" + message = "{{colocation_node Four_with_three}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(Three_with_one)", result) self.assertNotIn( @@ -269,13 +269,13 @@ class InterpolateColocationSummaryTest(test.TestCase): "Node One should not appear in Four_with_three's summary:\n%s" % result) def testNodeFiveHasColocationInterpolationForNodeOneAndTwo(self): - message = "^^colocation_node:Five_with_one_with_two^^" + message = "{{colocation_node Five_with_one_with_two}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("colocate_with(One)", result) self.assertIn("colocate_with(Two)", result) def testColocationInterpolationForNodeLackingColocation(self): - message = "^^colocation_node:One^^" + message = "{{colocation_node One}}" result = error_interpolation.interpolate(message, self.graph) self.assertIn("No node-device colocations", result) self.assertNotIn("Two", result) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt index eb41deee13..9f6dcd8fdb 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt @@ -8,17 +8,15 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } - field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt index e565b903d2..f3a515163d 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt @@ -131,18 +131,16 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } - field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt index eb41deee13..9f6dcd8fdb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.-experimental.pbtxt @@ -8,17 +8,15 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } - field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt index e565b903d2..f3a515163d 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-config-proto.pbtxt @@ -131,18 +131,16 @@ tf_proto { label: LABEL_OPTIONAL type: TYPE_STRING } - field { - name: "client_handles_error_formatting" - number: 2 - label: LABEL_OPTIONAL - type: TYPE_BOOL - } field { name: "executor_type" number: 3 label: LABEL_OPTIONAL type: TYPE_STRING } + reserved_range { + start: 2 + end: 3 + } } } } -- GitLab From 964c1dfcc9e55fbaf9e31efd310385b6fe2563d7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 17:04:02 -0700 Subject: [PATCH 089/540] Add support for quantized (hybrid) bidirectional sequential LSTM Op. PiperOrigin-RevId: 211552101 --- .../kernels/bidirectional_sequence_lstm.cc | 699 ++++++++++++++---- 1 file changed, 546 insertions(+), 153 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index af47b33922..cde4f55a16 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -108,9 +108,26 @@ constexpr int kBwInputCellStateTensor = 38; constexpr int kFwOutputTensor = 0; constexpr int kBwOutputTensor = 1; +// Temporary tensors. +enum TemporaryTensor { + // Scratch buffers for input, forget, etc. gates + kFwScratchBuffer = 0, + kBwScratchBuffer = 1, + // Quantized tensors needed for the hybrid kernel. + kInputQuantized = 2, + kFwActivationStateQuantized = 3, + kBwActivationStateQuantized = 4, + kFwCellStateQuantized = 5, + kBwCellStateQuantized = 6, + kScalingFactors = 7, + kProductScalingFactors = 8, + kRecoveredCellWeights = 9, + kNumTemporaryTensors = 10 +}; + void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index); return scratch_tensor_index; } @@ -131,7 +148,7 @@ TfLiteStatus CheckLstmTensorDimensions( int input_gate_bias_tensor, int forget_gate_bias_tensor, int cell_gate_bias_tensor, int output_gate_bias_tensor, int projection_weights_tensor, int projection_bias_tensor) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -324,7 +341,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); - TF_LITE_ENSURE(context, input->dims->size > 1); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->dims->size, 3); const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; @@ -370,11 +388,19 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output, fw_output_size)); - // Create a scratch buffer tensor. + // The weights are of consistent type, so it suffices to check one. + const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); - node->temporaries->data[0] = *scratch_tensor_index; - TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0); + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors); + } else { + node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers. + } + // Create a scratch buffer tensor. + node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index; + TfLiteTensor* fw_scratch_buffer = + GetTemporary(context, node, kFwScratchBuffer); fw_scratch_buffer->type = input->type; fw_scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -435,8 +461,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell); // Create a scratch buffer tensor. - node->temporaries->data[1] = *(scratch_tensor_index) + 1; - TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1); + node->temporaries->data[kBwScratchBuffer] = + *(scratch_tensor_index) + kBwScratchBuffer; + TfLiteTensor* bw_scratch_buffer = + GetTemporary(context, node, kBwScratchBuffer); bw_scratch_buffer->type = input->type; bw_scratch_buffer->allocation_type = kTfLiteArenaRw; @@ -454,18 +482,441 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, bw_scratch_buffer_size)); + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. + node->temporaries->data[kInputQuantized] = + *scratch_tensor_index + kInputQuantized; + TfLiteTensor* input_quantized = + GetTemporary(context, node, kInputQuantized); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + + node->temporaries->data[kFwActivationStateQuantized] = + *scratch_tensor_index + kFwActivationStateQuantized; + TfLiteTensor* fw_activation_state_quantized = + GetTemporary(context, node, kFwActivationStateQuantized); + fw_activation_state_quantized->type = kTfLiteUInt8; + fw_activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims, + fw_activation_state->dims)) { + TfLiteIntArray* fw_activation_state_quantized_size = + TfLiteIntArrayCopy(fw_activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, fw_activation_state_quantized, + fw_activation_state_quantized_size)); + } + node->temporaries->data[kBwActivationStateQuantized] = + *scratch_tensor_index + kBwActivationStateQuantized; + TfLiteTensor* bw_activation_state_quantized = + GetTemporary(context, node, kBwActivationStateQuantized); + bw_activation_state_quantized->type = kTfLiteUInt8; + bw_activation_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims, + bw_activation_state->dims)) { + TfLiteIntArray* bw_activation_state_quantized_size = + TfLiteIntArrayCopy(bw_activation_state->dims); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, bw_activation_state_quantized, + bw_activation_state_quantized_size)); + } + node->temporaries->data[kFwCellStateQuantized] = + *scratch_tensor_index + kFwCellStateQuantized; + TfLiteTensor* fw_cell_state_quantized = + GetTemporary(context, node, kFwCellStateQuantized); + fw_cell_state_quantized->type = kTfLiteUInt8; + fw_cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims, + fw_cell_state->dims)) { + TfLiteIntArray* fw_cell_state_quantized_size = + TfLiteIntArrayCopy(fw_cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, fw_cell_state_quantized, + fw_cell_state_quantized_size)); + } + node->temporaries->data[kBwCellStateQuantized] = + *scratch_tensor_index + kBwCellStateQuantized; + TfLiteTensor* bw_cell_state_quantized = + GetTemporary(context, node, kBwCellStateQuantized); + bw_cell_state_quantized->type = kTfLiteUInt8; + bw_cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims, + bw_cell_state->dims)) { + TfLiteIntArray* bw_cell_state_quantized_size = + TfLiteIntArrayCopy(bw_cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, bw_cell_state_quantized, + bw_cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[kScalingFactors] = + *scratch_tensor_index + kScalingFactors; + TfLiteTensor* scaling_factors = + GetTemporary(context, node, kScalingFactors); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[kProductScalingFactors] = + *scratch_tensor_index + kProductScalingFactors; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, kProductScalingFactors); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[kRecoveredCellWeights] = + *scratch_tensor_index + kRecoveredCellWeights; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, kRecoveredCellWeights); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_fw_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } + } + return kTfLiteOk; +} + +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, + TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + const float* input_to_input_weights_ptr = + (use_cifg) ? nullptr : input_to_input_weights->data.f; + const float* recurrent_to_input_weights_ptr = + (use_cifg) ? nullptr : recurrent_to_input_weights->data.f; + const float* input_gate_bias_ptr = + (use_cifg) ? nullptr : input_gate_bias->data.f; + const float* cell_to_input_weights_ptr = + (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr; + const float* cell_to_forget_weights_ptr = + (use_peephole) ? cell_to_forget_weights->data.f : nullptr; + const float* cell_to_output_weights_ptr = + (use_peephole) ? cell_to_output_weights->data.f : nullptr; + const float* projection_weights_ptr = + (projection_weights == nullptr) ? nullptr : projection_weights->data.f; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Loop through the sequence. + if (forward_sequence) { + for (int t = 0; t < max_time; t++) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr_time = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, + input_to_forget_weights->data.f, input_to_cell_weights->data.f, + input_to_output_weights->data.f, recurrent_to_input_weights_ptr, + recurrent_to_forget_weights->data.f, + recurrent_to_cell_weights->data.f, + recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, + cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, + output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, + params, n_batch, n_cell, n_input, n_output, activation_state->data.f, + cell_state->data.f, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, output_ptr_time); + } + } else { + // Loop through the sequence backwards. + for (int t = max_time - 1; t >= 0; t--) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr_time = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, + input_to_forget_weights->data.f, input_to_cell_weights->data.f, + input_to_output_weights->data.f, recurrent_to_input_weights_ptr, + recurrent_to_forget_weights->data.f, + recurrent_to_cell_weights->data.f, + recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, + cell_to_forget_weights_ptr, cell_to_output_weights_ptr, + input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, + output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, + params, n_batch, n_cell, n_input, n_output, activation_state->data.f, + cell_state->data.f, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, output_ptr_time); + } + } + return kTfLiteOk; +} + +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, bool forward_sequence, + TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, + TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, + TfLiteTensor* input_quantized, TfLiteTensor* output_state_quantized, + TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, + TfLiteTensor* cell_state, TfLiteTensor* output) { + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + if (forward_sequence) { + // Feed the sequence into the LSTM step-by-step. + for (int t = 0; t < max_time; t++) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, + n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors_ptr, + prod_scaling_factors_ptr, recovered_cell_weights_ptr, + quantized_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr); + } + } else { + // Loop through the sequence backwards. + for (int t = max_time - 1; t >= 0; t--) { + const float* input_ptr = input->data.f + t * n_batch * n_input; + float* output_ptr = output->data.f + t * n_batch * n_output; + + kernel_utils::LstmStep( + input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, + projection_weights_scale, projection_bias_ptr, params, n_batch, + n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors_ptr, + prod_scaling_factors_ptr, recovered_cell_weights_ptr, + quantized_input_ptr, quantized_output_state_ptr, + quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, + output_ptr); + } + } + return kTfLiteOk; } // The LSTM Op engine. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Input tensor. const TfLiteTensor* input = GetInput(context, node, kInputTensor); - const int max_time = input->dims->data[0]; - const int n_batch = input->dims->data[1]; - const int n_input = input->dims->data[2]; // Tensors for the forward cell. const TfLiteTensor* fw_input_to_input_weights = @@ -559,149 +1010,91 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetVariableInput(context, node, kBwInputCellStateTensor); TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); - // n_cell and n_output will be the same size when there is no projection. - const int n_fw_cell = fw_input_to_output_weights->dims->data[0]; - const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool fw_use_cifg = (fw_input_to_input_weights == nullptr); - const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr); - - // Index the scratch buffers pointers to the global scratch buffer. TfLiteTensor* fw_scratch_buffer = - &context->tensors[node->temporaries->data[0]]; - float* fw_input_gate_scratch = nullptr; - float* fw_cell_scratch = nullptr; - float* fw_forget_gate_scratch = nullptr; - float* fw_output_gate_scratch = nullptr; - if (fw_use_cifg) { - fw_cell_scratch = fw_scratch_buffer->data.f; - fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch; - fw_output_gate_scratch = - fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch; - } else { - fw_input_gate_scratch = fw_scratch_buffer->data.f; - fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch; - fw_forget_gate_scratch = - fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch; - fw_output_gate_scratch = - fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch; - } - - // Check optional tensors, the respective pointers can be null. - const float* fw_input_to_input_weights_ptr = - (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f; - const float* fw_recurrent_to_input_weights_ptr = - (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f; - const float* fw_input_gate_bias_ptr = - (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f; - const float* fw_cell_to_input_weights_ptr = - (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f - : nullptr; - const float* fw_cell_to_forget_weights_ptr = - (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr; - const float* fw_cell_to_output_weights_ptr = - (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr; - const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr) - ? nullptr - : fw_projection_weights->data.f; - const float* fw_projection_bias_ptr = - (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f; - - // Loop through the sequence. - for (int t = 0; t < max_time; t++) { - const float* input_ptr_batch = input->data.f + t * n_batch * n_input; - float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output; - - kernel_utils::LstmStep( - input_ptr_batch, fw_input_to_input_weights_ptr, - fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f, - fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr, - fw_recurrent_to_forget_weights->data.f, - fw_recurrent_to_cell_weights->data.f, - fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr, - fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr, - fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f, - fw_cell_bias->data.f, fw_output_gate_bias->data.f, - fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch, - n_fw_cell, n_input, n_fw_output, fw_activation_state->data.f, - fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch, - fw_cell_scratch, fw_output_gate_scratch, output_ptr_time); - } - - // n_cell and n_output will be the same size when there is no projection. - const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; - const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; - - // Since we have already checked that weights are all there or none, we can - // check the existense of only one to the get the condition. - const bool bw_use_cifg = (bw_input_to_input_weights == nullptr); - const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr); - - // Index the scratch buffers pointers to the global scratch buffer. + GetTemporary(context, node, kFwScratchBuffer); TfLiteTensor* bw_scratch_buffer = - &context->tensors[node->temporaries->data[1]]; - float* bw_input_gate_scratch = nullptr; - float* bw_cell_scratch = nullptr; - float* bw_forget_gate_scratch = nullptr; - float* bw_output_gate_scratch = nullptr; - if (bw_use_cifg) { - bw_cell_scratch = bw_scratch_buffer->data.f; - bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch; - bw_output_gate_scratch = - bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch; - } else { - bw_input_gate_scratch = bw_scratch_buffer->data.f; - bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch; - bw_forget_gate_scratch = - bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch; - bw_output_gate_scratch = - bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch; + GetTemporary(context, node, kBwScratchBuffer); + + switch (fw_input_to_output_weights->type) { + case kTfLiteFloat32: { + TfLiteStatus fw_pass_status = EvalFloat( + input, fw_input_to_input_weights, fw_input_to_forget_weights, + fw_input_to_cell_weights, fw_input_to_output_weights, + fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, + fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, + fw_cell_to_input_weights, fw_cell_to_forget_weights, + fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, + fw_cell_bias, fw_output_gate_bias, fw_projection_weights, + fw_projection_bias, params, /*forward_sequence=*/true, + fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output); + TF_LITE_ENSURE_OK(context, fw_pass_status); + + TfLiteStatus bw_pass_status = EvalFloat( + input, bw_input_to_input_weights, bw_input_to_forget_weights, + bw_input_to_cell_weights, bw_input_to_output_weights, + bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, + bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, + bw_cell_to_input_weights, bw_cell_to_forget_weights, + bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, + bw_cell_bias, bw_output_gate_bias, bw_projection_weights, + bw_projection_bias, params, /*forward_sequence=*/false, + bw_scratch_buffer, bw_activation_state, bw_cell_state, bw_output); + TF_LITE_ENSURE_OK(context, bw_pass_status); + return kTfLiteOk; + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = + GetTemporary(context, node, kInputQuantized); + TfLiteTensor* fw_activation_state_quantized = + GetTemporary(context, node, kFwActivationStateQuantized); + TfLiteTensor* bw_activation_state_quantized = + GetTemporary(context, node, kBwActivationStateQuantized); + TfLiteTensor* fw_cell_state_quantized = + GetTemporary(context, node, kFwCellStateQuantized); + TfLiteTensor* bw_cell_state_quantized = + GetTemporary(context, node, kBwCellStateQuantized); + TfLiteTensor* scaling_factors = + GetTemporary(context, node, kScalingFactors); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, kProductScalingFactors); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, kRecoveredCellWeights); + TfLiteStatus fw_pass_status = EvalHybrid( + input, fw_input_to_input_weights, fw_input_to_forget_weights, + fw_input_to_cell_weights, fw_input_to_output_weights, + fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, + fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, + fw_cell_to_input_weights, fw_cell_to_forget_weights, + fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, + fw_cell_bias, fw_output_gate_bias, fw_projection_weights, + fw_projection_bias, params, /*forward_sequence=*/true, + fw_scratch_buffer, scaling_factors, prod_scaling_factors, + recovered_cell_weights, input_quantized, + fw_activation_state_quantized, fw_cell_state_quantized, + fw_activation_state, fw_cell_state, fw_output); + TF_LITE_ENSURE_OK(context, fw_pass_status); + + TfLiteStatus bw_pass_status = EvalHybrid( + input, bw_input_to_input_weights, bw_input_to_forget_weights, + bw_input_to_cell_weights, bw_input_to_output_weights, + bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, + bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, + bw_cell_to_input_weights, bw_cell_to_forget_weights, + bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, + bw_cell_bias, bw_output_gate_bias, bw_projection_weights, + bw_projection_bias, params, /*forward_sequence=*/false, + bw_scratch_buffer, scaling_factors, prod_scaling_factors, + recovered_cell_weights, input_quantized, + bw_activation_state_quantized, bw_cell_state_quantized, + bw_activation_state, bw_cell_state, bw_output); + TF_LITE_ENSURE_OK(context, bw_pass_status); + return kTfLiteOk; + } + default: + context->ReportError(context, "Type %d is not currently supported.", + fw_input_to_output_weights->type); + return kTfLiteError; } - - // Check optional tensors, the respective pointers can be null. - const float* bw_input_to_input_weights_ptr = - (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f; - const float* bw_recurrent_to_input_weights_ptr = - (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f; - const float* bw_input_gate_bias_ptr = - (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f; - const float* bw_cell_to_input_weights_ptr = - (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f - : nullptr; - const float* bw_cell_to_forget_weights_ptr = - (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr; - const float* bw_cell_to_output_weights_ptr = - (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr; - const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr) - ? nullptr - : bw_projection_weights->data.f; - const float* bw_projection_bias_ptr = - (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f; - - // Loop through the sequence backwards. - for (int t = max_time - 1; t >= 0; t--) { - const float* input_ptr_batch = input->data.f + t * n_batch * n_input; - float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output; - - kernel_utils::LstmStep( - input_ptr_batch, bw_input_to_input_weights_ptr, - bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f, - bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr, - bw_recurrent_to_forget_weights->data.f, - bw_recurrent_to_cell_weights->data.f, - bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr, - bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr, - bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f, - bw_cell_bias->data.f, bw_output_gate_bias->data.f, - bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch, - n_bw_cell, n_input, n_bw_output, bw_activation_state->data.f, - bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch, - bw_cell_scratch, bw_output_gate_scratch, output_ptr_time); - } - - // Backward step. return kTfLiteOk; } -- GitLab From 9c7ca4c83b2e98517d0ccbba81b6b7fbc178d731 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9C=A8=E5=8E=9F=E4=BD=90=E4=B8=BA?= Date: Wed, 5 Sep 2018 08:15:42 +0800 Subject: [PATCH 090/540] use ndims --- tensorflow/contrib/autograph/operators/slices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py index a885bdab5b..4b3f7ebee8 100644 --- a/tensorflow/contrib/autograph/operators/slices.py +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -58,7 +58,7 @@ def get_item(target, i, opts): elif tensor_util.is_tensor(target): if target.dtype == dtypes.variant: return _tf_tensor_list_get_item(target, i, opts) - elif target.dtype == dtypes.string and target.get_shape() == (): # target is string with rank 0 + elif target.dtype == dtypes.string and target.shape.ndims == 0: # target is string with rank 0 return _tf_tensor_string_get_item(target, i) else: return _tf_tensor_get_item(target, i) -- GitLab From f3ee2c74e9e3a79266503f5c4275c919303fd568 Mon Sep 17 00:00:00 2001 From: Piotr Padlewski Date: Tue, 4 Sep 2018 17:15:21 -0700 Subject: [PATCH 091/540] Move GrapplerFunctionItem arguments. This patch uses take by value and move idiom to optimize copying of constructor arguments. PiperOrigin-RevId: 211553877 --- tensorflow/core/grappler/utils/functions.cc | 32 ++++++++++----------- tensorflow/core/grappler/utils/functions.h | 13 ++++----- 2 files changed, 22 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index a2c363ea6e..a428aea7f5 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -304,21 +304,21 @@ Status GrapplerFunctionItemInstantiation::GetArgType( } GrapplerFunctionItem::GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const std::vector& input_arg_expansions, - const std::vector& output_arg_expansions, - const std::vector& keep_nodes, const int graph_def_version, - bool is_stateful, GraphDef&& function_body) - : description_(description), - func_attr_(func_attr), - input_arg_expansions_(input_arg_expansions), - output_arg_expansions_(output_arg_expansions), + string func_name, string description, AttrValueMap func_attr, + std::vector input_arg_expansions, + std::vector output_arg_expansions, + std::vector keep_nodes, const int graph_def_version, + const bool is_stateful, GraphDef&& function_body) + : description_(std::move(description)), + func_attr_(std::move(func_attr)), + input_arg_expansions_(std::move(input_arg_expansions)), + output_arg_expansions_(std::move(output_arg_expansions)), is_stateful_(is_stateful) { - id = func_name; - keep_ops = keep_nodes; - // Swap the graph body. - graph.Swap(&function_body); + // Move assign GrapplerItem members. + keep_ops = std::move(keep_nodes); + id = std::move(func_name); + graph = std::move(function_body); + graph.mutable_versions()->set_producer(graph_def_version); // Fill the feed nodes with input placeholders. for (const InputArgExpansion& input_arg : input_arg_expansions_) { @@ -598,8 +598,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, *item = GrapplerFunctionItem( /*func_name=*/signature.name(), /*description=*/signature.description(), /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()), - inputs, outputs, keep_nodes, graph_def_version, is_stateful, - std::move(function_body)); + std::move(inputs), std::move(outputs), std::move(keep_nodes), + graph_def_version, is_stateful, std::move(function_body)); return Status::OK(); } diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 61588ceb83..733caf325f 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -136,13 +136,12 @@ class GrapplerFunctionItemInstantiation { class GrapplerFunctionItem : public GrapplerItem { public: GrapplerFunctionItem() = default; - GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const std::vector& input_arg_expansions, - const std::vector& output_arg_expansions, - const std::vector& keep_nodes, const int versions, - bool is_stateful, GraphDef&& function_body); + GrapplerFunctionItem(string func_name, string description, + AttrValueMap func_attr, + std::vector input_arg_expansions, + std::vector output_arg_expansions, + std::vector keep_nodes, int graph_def_version, + bool is_stateful, GraphDef&& function_body); const string& description() const; -- GitLab From 65899c10ab9a384670369257662c7c00fca12f19 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 4 Sep 2018 17:46:44 -0700 Subject: [PATCH 092/540] Fix compiler warnings in `DebugNanCountOp` and `DebugNumericSummaryOp`. PiperOrigin-RevId: 211557740 --- tensorflow/core/kernels/debug_ops.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 33ed5522d0..d705e82b0d 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -255,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp { TensorShape shape({1}); OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor)); output_tensor->vec()(0) = nan_count; - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } }; @@ -380,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp { bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 && positive_inf_count == 0; if (!mute) { - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } } -- GitLab From bfde272cf661d942b11877a8709739a09c5d41fd Mon Sep 17 00:00:00 2001 From: Michael Case Date: Tue, 4 Sep 2018 17:46:47 -0700 Subject: [PATCH 093/540] Disable variable partitioning from TPU DNN canned estimator. PiperOrigin-RevId: 211557743 --- tensorflow/python/estimator/canned/dnn.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index c08cf61220..1c0c4581c0 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -142,7 +142,7 @@ def _dnn_model_fn(features, dropout=None, input_layer_partitioner=None, config=None, - tpu_estimator_spec=False, + use_tpu=False, batch_norm=False): """Deep Neural Net model_fn. @@ -164,8 +164,8 @@ def _dnn_model_fn(features, input_layer_partitioner: Partitioner for input layer. Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. config: `RunConfig` object to configure the runtime settings. - tpu_estimator_spec: Whether to return a `_TPUEstimatorSpec` or - or `model_fn.EstimatorSpec` instance. + use_tpu: Whether to make a DNN model able to run on TPU. Will make function + return a `_TPUEstimatorSpec` instance and disable variable partitioning. batch_norm: Whether to use batch normalization after each hidden layer. Returns: @@ -182,13 +182,15 @@ def _dnn_model_fn(features, optimizer, learning_rate=_LEARNING_RATE) num_ps_replicas = config.num_ps_replicas if config else 0 - partitioner = partitioned_variables.min_max_variable_partitioner( - max_partitions=num_ps_replicas) + partitioner = (None if use_tpu else + partitioned_variables.min_max_variable_partitioner( + max_partitions=num_ps_replicas)) with variable_scope.variable_scope( 'dnn', values=tuple(six.itervalues(features)), partitioner=partitioner): input_layer_partitioner = input_layer_partitioner or ( + None if use_tpu else partitioned_variables.min_max_variable_partitioner( max_partitions=num_ps_replicas, min_slice_size=64 << 20)) @@ -203,7 +205,7 @@ def _dnn_model_fn(features, batch_norm=batch_norm) logits = logit_fn(features=features, mode=mode) - if tpu_estimator_spec: + if use_tpu: return head._create_tpu_estimator_spec( # pylint: disable=protected-access features=features, mode=mode, -- GitLab From fd28fee75f141345c3e862bc1115ff4a2b478eb0 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Tue, 4 Sep 2018 18:10:09 -0700 Subject: [PATCH 094/540] [XLA] Don't show trivial feature_group_count attributes If the feature_group_count is 1, don't bother showing it as it is not very informative and a very common scenario. This is consistent with the HloCustomCall's feature_group_count attribute. PiperOrigin-RevId: 211560372 --- tensorflow/compiler/xla/service/BUILD | 2 ++ .../compiler/xla/service/hlo_instructions.cc | 4 +++- .../compiler/xla/service/hlo_parser_test.cc | 21 ++++++++++++++++--- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 26b48cf419..f6cfac6537 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3289,6 +3289,8 @@ tf_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo", + ":hlo_casting_utils", ":hlo_matchers", ":hlo_parser", "//tensorflow/compiler/xla:window_util", diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index bed273149b..e3683aaec9 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1674,7 +1674,9 @@ std::vector HloConvolutionInstruction::ExtraAttributesToStringImpl( } extra.push_back(StrCat("dim_labels=", ConvolutionDimensionNumbersToString( convolution_dimension_numbers_))); - extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + if (feature_group_count_ != 1) { + extra.push_back(StrCat("feature_group_count=", feature_group_count_)); + } return extra; } diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 759789437c..0dfc0a4d1c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -19,6 +19,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -382,7 +384,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 %input = f32[1,2,1]{2,1,0} parameter(0) %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) %filter = f32[1,1,1]{2,1,0} parameter(1) - ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, feature_group_count=1, operand_precision={high,default} + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default} } )" @@ -395,7 +397,7 @@ R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) %filter = f32[1,1]{1,0} parameter(1) - ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf, feature_group_count=1 + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf } )" @@ -408,7 +410,7 @@ R"(HloModule ConvolveBackward_module ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { %input = f32[128,7,7,512]{0,3,2,1} parameter(0) %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) - ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f, feature_group_count=1 + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f } )" @@ -1775,5 +1777,18 @@ TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); } +TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { + const string text = + R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Convolution(op::Parameter(0), op::Parameter(1))); + auto* convolution = + Cast(computation->root_instruction()); + EXPECT_EQ(convolution->feature_group_count(), 1); +} + } // namespace } // namespace xla -- GitLab From e1ba7ee122d218dd39cd423b821078d36b5663d1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 18:32:06 -0700 Subject: [PATCH 095/540] Hardcode input range from output for relu PiperOrigin-RevId: 211562900 --- .../graph_transformations/hardcode_min_max.cc | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 502de88f7c..3114fa93e8 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -63,6 +63,25 @@ bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) { return true; } +bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) { + auto& input = model->GetArray(op->inputs[0]); + if (input.minmax) { + const auto* minmax = input.minmax.get(); + if (minmax) { + return false; + } + } + auto& output = model->GetArray(op->outputs[0]); + if (output.minmax) { + const auto* minmax = model->GetArray(op->outputs[0]).minmax.get(); + if (minmax) { + input.GetOrCreateMinMax() = *minmax; + return true; + } + } + return false; +} + bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { // Do not early return if the output already has min/max: // we may still need to adjust the inputs min/max. @@ -366,6 +385,16 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForL2Normalization(model, op); break; + case OperatorType::kRelu: + // For any normalization other than batch norm, the quantizations ranges + // before and after relu are expected to be known. Having a quantization + // op before relu would reduce the number of bits of precision for the + // activation in half. So we deduce the range before relu from that after + // the relu. This would eliminate the need for two fake quantization nodes + // and would not reduce the bits of precision available for activation. + changed = HardcodeInputMinMaxFromOutput(model, op); + break; + case OperatorType::kConcatenation: changed = HardcodeMinMaxForConcatenation(model, op); break; -- GitLab From 30db26a5f4983b248bd4565d08c59155ad8bb36c Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 4 Sep 2018 18:32:44 -0700 Subject: [PATCH 096/540] Test cleanups - Remove unnecessary use of test_session() in tests that run with eager execution enabled. - Use cached_session() instead of test_session() (self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement.) PiperOrigin-RevId: 211562969 --- .../python/kernel_tests/core_rnn_cell_test.py | 73 ++-- .../data/kernel_tests/iterator_ops_test.py | 87 ++-- tensorflow/python/eager/backprop_test.py | 8 +- .../python/kernel_tests/check_ops_test.py | 80 ++-- .../kernel_tests/functional_ops_test.py | 405 +++++++++--------- .../python/kernel_tests/list_ops_test.py | 12 +- .../python/kernel_tests/py_func_test.py | 87 ++-- .../resource_variable_ops_test.py | 39 +- tensorflow/python/kernel_tests/rnn_test.py | 18 +- 9 files changed, 388 insertions(+), 421 deletions(-) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 15ce9d1ce7..be0306cb07 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -48,7 +48,7 @@ Linear = core_rnn_cell._Linear # pylint: disable=invalid-name class RNNCellTest(test.TestCase): def testLinear(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(1.0)): x = array_ops.zeros([1, 2]) @@ -69,7 +69,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(variables_lib.trainable_variables()), 2) def testBasicRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -89,7 +89,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testBasicRNNCellNotTrainable(self): - with self.test_session() as sess: + with self.cached_session() as sess: def not_trainable_getter(getter, *args, **kwargs): kwargs["trainable"] = False @@ -116,7 +116,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testIndRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -137,7 +137,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testGRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -165,7 +165,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.156736, 0.156736]]) def testIndyGRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -193,7 +193,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.155127, 0.157328]]) def testSRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -208,7 +208,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.509682, 0.509682]]) def testSRUCellWithDiffSize(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -288,7 +288,7 @@ class RNNCellTest(test.TestCase): def testBasicLSTMCellDimension0Error(self): """Tests that dimension 0 in both(x and m) shape must be equal.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): num_units = 2 @@ -309,7 +309,7 @@ class RNNCellTest(test.TestCase): def testBasicLSTMCellStateSizeError(self): """Tests that state_size must be num_units * 2.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): num_units = 2 @@ -329,7 +329,7 @@ class RNNCellTest(test.TestCase): }) def testBasicLSTMCellStateTupleType(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -360,7 +360,7 @@ class RNNCellTest(test.TestCase): self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) def testBasicLSTMCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -459,7 +459,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(res), 2) def testLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 num_proj = 6 state_size = num_units + num_proj @@ -494,7 +494,7 @@ class RNNCellTest(test.TestCase): float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) def testLSTMCellVariables(self): - with self.test_session(): + with self.cached_session(): num_units = 8 num_proj = 6 state_size = num_units + num_proj @@ -517,7 +517,7 @@ class RNNCellTest(test.TestCase): "root/lstm_cell/projection/kernel") def testLSTMCellLayerNorm(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 num_proj = 3 batch_size = 1 @@ -562,22 +562,21 @@ class RNNCellTest(test.TestCase): rnn_cell_impl.DropoutWrapper, rnn_cell_impl.ResidualWrapper, lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: - with self.test_session(): - cell = rnn_cell_impl.BasicRNNCell(1) - wrapper = wrapper_type(cell) - wrapper(array_ops.ones([1, 1]), - state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) - self.evaluate([v.initializer for v in cell.variables]) - checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(cell._bias.assign([40.])) - save_path = checkpoint.save(prefix) - self.evaluate(cell._bias.assign([0.])) - checkpoint.restore(save_path).assert_consumed().run_restore_ops() - self.assertAllEqual([40.], self.evaluate(cell._bias)) + cell = rnn_cell_impl.BasicRNNCell(1) + wrapper = wrapper_type(cell) + wrapper(array_ops.ones([1, 1]), + state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) + self.evaluate([v.initializer for v in cell.variables]) + checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(cell._bias.assign([40.])) + save_path = checkpoint.save(prefix) + self.evaluate(cell._bias.assign([0.])) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([40.], self.evaluate(cell._bias)) def testOutputProjectionWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -594,7 +593,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.231907, 0.231907]]) def testInputProjectionWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -612,7 +611,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) def testResidualWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -638,7 +637,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[2], res[3]) def testResidualWrapperWithSlice(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 5]) @@ -716,7 +715,7 @@ class RNNCellTest(test.TestCase): self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) def testEmbeddingWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 1], dtype=dtypes.int32) @@ -735,7 +734,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.17139, 0.17139]]) def testEmbeddingWrapperWithDynamicRnn(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root"): inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) @@ -753,7 +752,7 @@ class RNNCellTest(test.TestCase): sess.run(outputs) def testMultiRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -770,7 +769,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) def testMultiRNNCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -809,7 +808,7 @@ class DropoutWrapperTest(test.TestCase): time_steps=None, parallel_iterations=None, **kwargs): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): if batch_size is None and time_steps is None: diff --git a/tensorflow/python/data/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/kernel_tests/iterator_ops_test.py index b0414ad655..671e5d4812 100644 --- a/tensorflow/python/data/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_ops_test.py @@ -91,7 +91,7 @@ class IteratorTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) @@ -117,7 +117,7 @@ class IteratorTest(test.TestCase): self.assertEqual([c.shape[1:] for c in components], [t.shape for t in get_next]) - with self.test_session() as sess: + with self.cached_session() as sess: for _ in range(14): for i in range(7): result = sess.run(get_next) @@ -208,7 +208,7 @@ class IteratorTest(test.TestCase): iterator = dataset.make_one_shot_iterator() next_element = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) @@ -216,7 +216,7 @@ class IteratorTest(test.TestCase): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): sess.run(next_element) - with self.test_session() as sess: + with self.cached_session() as sess: def consumer_thread(): with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): @@ -287,7 +287,7 @@ class IteratorTest(test.TestCase): .make_initializable_iterator()) get_next = iterator.get_next() - with self.test_session() as sess: + with self.cached_session() as sess: with self.assertRaisesRegexp(errors.FailedPreconditionError, "iterator has not been initialized"): sess.run(get_next) @@ -308,7 +308,7 @@ class IteratorTest(test.TestCase): self.assertEqual(dataset_4.output_types, iterator.output_types) self.assertEqual([None], iterator.output_shapes.as_list()) - with self.test_session() as sess: + with self.cached_session() as sess: # The iterator is initially uninitialized. with self.assertRaises(errors.FailedPreconditionError): sess.run(get_next) @@ -380,7 +380,7 @@ class IteratorTest(test.TestCase): self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) self.assertEqual([], feedable_iterator.output_shapes) - with self.test_session() as sess: + with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) @@ -436,7 +436,7 @@ class IteratorTest(test.TestCase): self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) self.assertEqual([], feedable_iterator.output_shapes) - with self.test_session() as sess: + with self.cached_session() as sess: iterator_3_handle = sess.run(iterator_3.string_handle()) iterator_4_handle = sess.run(iterator_4.string_handle()) @@ -524,7 +524,7 @@ class IteratorTest(test.TestCase): feedable_int_any = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32) - with self.test_session() as sess: + with self.cached_session() as sess: handle_int_scalar = sess.run( dataset_int_scalar.make_one_shot_iterator().string_handle()) handle_float_vector = sess.run( @@ -687,7 +687,7 @@ class IteratorTest(test.TestCase): f=_remote_fn, target=target_placeholder) - with self.test_session() as sess: + with self.cached_session() as sess: elem = sess.run( remote_op, feed_dict={ @@ -803,16 +803,15 @@ class IteratorCheckpointingTest(test.TestCase): get_next = iterator.get_next if context.executing_eagerly( ) else functools.partial(self.evaluate, iterator.get_next()) checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) - with self.test_session() as sess: - self.assertAllEqual([1, 4], get_next()) - save_path = checkpoint.save(checkpoint_prefix) - self.assertAllEqual([9, 16], get_next()) - self.assertAllEqual([25, 36], get_next()) - checkpoint.restore(save_path).run_restore_ops(sess) - self.assertAllEqual([9, 16], get_next()) - self.assertAllEqual([25, 36], get_next()) - with self.assertRaises(errors.OutOfRangeError): - get_next() + self.assertAllEqual([1, 4], get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + checkpoint.restore(save_path).run_restore_ops() + self.assertAllEqual([9, 16], get_next()) + self.assertAllEqual([25, 36], get_next()) + with self.assertRaises(errors.OutOfRangeError): + get_next() @test_util.run_in_graph_and_eager_modes def testSaveRestoreMultipleIterator(self): @@ -833,19 +832,18 @@ class IteratorCheckpointingTest(test.TestCase): ) else functools.partial(self.evaluate, iterator_3.get_next()) checkpoint = checkpointable_utils.Checkpoint( iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3) - with self.test_session() as sess: - self.assertAllEqual([1, 4], get_next_1()) - self.assertAllEqual(0, get_next_3()) - self.assertAllEqual(1, get_next_3()) - self.assertAllEqual(2, get_next_3()) - save_path = checkpoint.save(checkpoint_prefix) - self.assertAllEqual([1, 4], get_next_2()) - self.assertAllEqual([9, 16], get_next_2()) - self.assertAllEqual(3, get_next_3()) - checkpoint.restore(save_path).run_restore_ops(sess) - self.assertAllEqual([9, 16], get_next_1()) - self.assertAllEqual([1, 4], get_next_2()) - self.assertAllEqual(3, get_next_3()) + self.assertAllEqual([1, 4], get_next_1()) + self.assertAllEqual(0, get_next_3()) + self.assertAllEqual(1, get_next_3()) + self.assertAllEqual(2, get_next_3()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual([9, 16], get_next_2()) + self.assertAllEqual(3, get_next_3()) + checkpoint.restore(save_path).run_restore_ops() + self.assertAllEqual([9, 16], get_next_1()) + self.assertAllEqual([1, 4], get_next_2()) + self.assertAllEqual(3, get_next_3()) @test_util.run_in_graph_and_eager_modes def testRestoreExhaustedIterator(self): @@ -856,17 +854,16 @@ class IteratorCheckpointingTest(test.TestCase): get_next = iterator.get_next if context.executing_eagerly( ) else functools.partial(self.evaluate, iterator.get_next()) checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) - with self.test_session() as sess: - self.assertAllEqual(0, get_next()) - self.assertAllEqual(1, get_next()) - save_path = checkpoint.save(checkpoint_prefix) - self.assertAllEqual(2, get_next()) - checkpoint.restore(save_path).run_restore_ops(sess) - self.assertAllEqual(2, get_next()) - save_path = checkpoint.save(checkpoint_prefix) - checkpoint.restore(save_path).run_restore_ops(sess) - with self.assertRaises(errors.OutOfRangeError): - get_next() + self.assertAllEqual(0, get_next()) + self.assertAllEqual(1, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + self.assertAllEqual(2, get_next()) + checkpoint.restore(save_path).run_restore_ops() + self.assertAllEqual(2, get_next()) + save_path = checkpoint.save(checkpoint_prefix) + checkpoint.restore(save_path).run_restore_ops() + with self.assertRaises(errors.OutOfRangeError): + get_next() def testRestoreInReconstructedIteratorInitializable(self): checkpoint_directory = self.get_temp_dir() @@ -876,7 +873,7 @@ class IteratorCheckpointingTest(test.TestCase): get_next = iterator.get_next() checkpoint = checkpointable_utils.Checkpoint(iterator=iterator) for i in range(5): - with self.test_session() as sess: + with self.cached_session() as sess: checkpoint.restore(checkpoint_management.latest_checkpoint( checkpoint_directory)).initialize_or_restore(sess) for j in range(2): diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index caf36b6a36..6673178ee7 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -64,7 +64,7 @@ class BackpropTest(test.TestCase): grad = backprop.gradients_function(fn, [0])(var)[0] grad = self.evaluate(ops.convert_to_tensor(grad)) - with context.graph_mode(), self.test_session(): + with context.graph_mode(): tf_var = array_ops.constant(var_np, dtypes.float32) tf_ind1 = array_ops.constant([0, 1]) tf_ind2 = array_ops.constant([2, 3]) @@ -79,7 +79,7 @@ class BackpropTest(test.TestCase): tf_dense_grad = math_ops.unsorted_segment_sum( tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0]) - self.assertAllClose(grad, tf_dense_grad.eval()) + self.assertAllClose(grad, self.evaluate(tf_dense_grad)) def testImplicitGradWithResourceVariable(self): x = resource_variable_ops.ResourceVariable( @@ -198,7 +198,7 @@ class BackpropTest(test.TestCase): grad = backprop.implicit_grad(f)()[0][0] opt = training.GradientDescentOptimizer(lrn_rate) - with context.graph_mode(), self.test_session(): + with context.graph_mode(), self.cached_session(): tf_x = array_ops.ones((batch_size), dtypes.int64) # TODO(ashankar,apassos): Change to ResourceVariable. tf_embedding = variables.Variable( @@ -941,7 +941,7 @@ class BackpropTest(test.TestCase): def testZerosCacheDoesntLeakAcrossGraphs(self): with context.graph_mode(): def get_grad(): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): t = constant_op.constant(1, dtype=dtypes.float32, shape=(10, 4)) x = constant_op.constant(2, dtype=dtypes.float32, shape=(10, 4)) with backprop.GradientTape() as tape: diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 05f998d0d2..680d0c97cc 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -116,7 +116,7 @@ class AssertEqualTest(test.TestCase): check_ops.assert_equal(static_big, static_small, message="fail") def test_raises_when_greater_dynamic(self): - with self.test_session(): + with self.cached_session(): small = array_ops.placeholder(dtypes.int32, name="small") big = array_ops.placeholder(dtypes.int32, name="big") with ops.control_dependencies( @@ -194,7 +194,7 @@ First 2 elements of y: check_ops.assert_equal(static_big, static_small, message="fail") def test_raises_when_less_dynamic(self): - with self.test_session(): + with self.cached_session(): small = array_ops.placeholder(dtypes.int32, name="small") big = array_ops.placeholder(dtypes.int32, name="big") with ops.control_dependencies([check_ops.assert_equal(small, big)]): @@ -271,30 +271,28 @@ class AssertNoneEqualTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def test_raises_when_not_equal_but_non_broadcastable_shapes(self): - with self.test_session(): - small = constant_op.constant([1, 1, 1], name="small") - big = constant_op.constant([10, 10], name="big") - # The exception in eager and non-eager mode is different because - # eager mode relies on shape check done as part of the C++ op, while - # graph mode does shape checks when creating the `Operation` instance. - with self.assertRaisesRegexp( - (ValueError, errors.InvalidArgumentError), - (r"Incompatible shapes: \[3\] vs. \[2\]|" - r"Dimensions must be equal, but are 3 and 2")): - with ops.control_dependencies( - [check_ops.assert_none_equal(small, big)]): - out = array_ops.identity(small) - self.evaluate(out) + small = constant_op.constant([1, 1, 1], name="small") + big = constant_op.constant([10, 10], name="big") + # The exception in eager and non-eager mode is different because + # eager mode relies on shape check done as part of the C++ op, while + # graph mode does shape checks when creating the `Operation` instance. + with self.assertRaisesRegexp( + (ValueError, errors.InvalidArgumentError), + (r"Incompatible shapes: \[3\] vs. \[2\]|" + r"Dimensions must be equal, but are 3 and 2")): + with ops.control_dependencies( + [check_ops.assert_none_equal(small, big)]): + out = array_ops.identity(small) + self.evaluate(out) @test_util.run_in_graph_and_eager_modes def test_doesnt_raise_when_both_empty(self): - with self.test_session(): - larry = constant_op.constant([]) - curly = constant_op.constant([]) - with ops.control_dependencies( - [check_ops.assert_none_equal(larry, curly)]): - out = array_ops.identity(larry) - self.evaluate(out) + larry = constant_op.constant([]) + curly = constant_op.constant([]) + with ops.control_dependencies( + [check_ops.assert_none_equal(larry, curly)]): + out = array_ops.identity(larry) + self.evaluate(out) def test_returns_none_with_eager(self): with context.eager_mode(): @@ -905,7 +903,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -923,7 +921,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -940,7 +938,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -957,7 +955,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -974,7 +972,7 @@ class AssertRankTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 2 with ops.control_dependencies( @@ -989,7 +987,7 @@ class AssertRankTest(test.TestCase): check_ops.assert_rank(tensor, np.array([], dtype=np.int32)) def test_raises_if_rank_is_not_scalar_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( [1, 2], dtype=dtypes.float32, name="my_tensor") rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor") @@ -1006,7 +1004,7 @@ class AssertRankTest(test.TestCase): check_ops.assert_rank(tensor, .5) def test_raises_if_rank_is_not_integer_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( [1, 2], dtype=dtypes.float32, name="my_tensor") rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor") @@ -1029,7 +1027,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank0)) def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") with ops.control_dependencies([ check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]): @@ -1045,7 +1043,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank0)) def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): with ops.control_dependencies([ @@ -1061,7 +1059,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank1)) def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): with ops.control_dependencies([ @@ -1079,7 +1077,7 @@ class AssertRankInTest(test.TestCase): self.evaluate(array_ops.identity(tensor_rank1)) def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") with ops.control_dependencies([ check_ops.assert_rank_in(tensor_rank1, (0, 2))]): @@ -1098,7 +1096,7 @@ class AssertRankInTest(test.TestCase): check_ops.assert_rank_in(tensor, desired_ranks) def test_raises_if_rank_is_not_scalar_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( (42, 43), dtype=dtypes.float32, name="my_tensor") desired_ranks = ( @@ -1120,7 +1118,7 @@ class AssertRankInTest(test.TestCase): check_ops.assert_rank_in(tensor, (1, .5,)) def test_raises_if_rank_is_not_integer_dynamic(self): - with self.test_session(): + with self.cached_session(): tensor = constant_op.constant( (42, 43), dtype=dtypes.float32, name="my_tensor") rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor") @@ -1143,7 +1141,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -1160,7 +1158,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -1176,7 +1174,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 0 with ops.control_dependencies( @@ -1192,7 +1190,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 1 with ops.control_dependencies( @@ -1209,7 +1207,7 @@ class AssertRankAtLeastTest(test.TestCase): self.evaluate(array_ops.identity(tensor)) def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self): - with self.test_session(): + with self.cached_session(): tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") desired_rank = 2 with ops.control_dependencies( diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 7739b13143..3ddb5e06c9 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -59,39 +59,36 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testFoldl_Simple(self): - with self.test_session(): - elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") + elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - r = functional_ops.foldl( - lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), - elems) - self.assertAllEqual(208, self.evaluate(r)) + r = functional_ops.foldl( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems) + self.assertAllEqual(208, self.evaluate(r)) - r = functional_ops.foldl( - lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), - elems, - initializer=10) - self.assertAllEqual(880, self.evaluate(r)) + r = functional_ops.foldl( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems, + initializer=10) + self.assertAllEqual(880, self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testFoldl_SingleInputMultiOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array([1, -1.0]) - r = functional_ops.foldl(lambda a, x: a + x, elems, initializer) - r_value = self.evaluate(r) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array([1, -1.0]) + r = functional_ops.foldl(lambda a, x: a + x, elems, initializer) + r_value = self.evaluate(r) - self.assertAllEqual(22, r_value[0]) - self.assertAllEqual(20, r_value[1]) + self.assertAllEqual(22, r_value[0]) + self.assertAllEqual(20, r_value[1]) @test_util.run_in_graph_and_eager_modes def testFoldl_MultiInputSingleOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems), - initializer) - self.assertAllEqual(1, self.evaluate(r)) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems), + initializer) + self.assertAllEqual(1, self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testFoldl_MultiInputDifferentDimsSingleOutput(self): @@ -103,7 +100,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual([1.0, 2.0, 3.0], self.evaluate(r)) def testFoldl_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root") as varscope: elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -123,42 +120,39 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testFoldr_Simple(self): - with self.test_session(): - elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") + elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - r = functional_ops.foldr( - lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), - elems) - self.assertAllEqual(450, self.evaluate(r)) + r = functional_ops.foldr( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems) + self.assertAllEqual(450, self.evaluate(r)) - r = functional_ops.foldr( - lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), - elems, - initializer=10) - self.assertAllEqual(1282, self.evaluate(r)) + r = functional_ops.foldr( + lambda a, x: math_ops.multiply(math_ops.add(a, x), 2), + elems, + initializer=10) + self.assertAllEqual(1282, self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testFoldr_SingleInputMultiOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array([1, -1.0]) - r = functional_ops.foldr(lambda a, x: a + x, elems, initializer) - r_value = self.evaluate(r) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array([1, -1.0]) + r = functional_ops.foldr(lambda a, x: a + x, elems, initializer) + r_value = self.evaluate(r) - self.assertAllEqual(22, r_value[0]) - self.assertAllEqual(20, r_value[1]) + self.assertAllEqual(22, r_value[0]) + self.assertAllEqual(20, r_value[1]) @test_util.run_in_graph_and_eager_modes def testFoldr_MultiInputSingleOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems), - initializer) - self.assertAllEqual(1, self.evaluate(r)) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems), + initializer) + self.assertAllEqual(1, self.evaluate(r)) def testFoldr_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root") as varscope: elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -178,7 +172,7 @@ class FunctionalOpsTest(test.TestCase): # pylint: disable=unnecessary-lambda def testFold_Grad(self): - with self.test_session(): + with self.cached_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") v = constant_op.constant(2.0, name="v") r = functional_ops.foldl( @@ -194,16 +188,15 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMap_Simple(self): - with self.test_session(): - nums = [1, 2, 3, 4, 5, 6] - elems = constant_op.constant(nums, name="data") - r = functional_ops.map_fn( - lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems) - self.assertAllEqual( - np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, name="data") + r = functional_ops.map_fn( + lambda x: math_ops.multiply(math_ops.add(x, 3), 2), elems) + self.assertAllEqual( + np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) def testMapSparseTensor(self): - with self.test_session(): + with self.cached_session(): with self.assertRaises(TypeError): functional_ops.map_fn( lambda x: x, @@ -220,7 +213,7 @@ class FunctionalOpsTest(test.TestCase): functional_ops.map_fn(lambda x: x, 1) def testMap_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: def double_scoped(x): """2x with a dummy 2 that is scoped.""" @@ -251,7 +244,7 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(doubles, self.evaluate(r)) def testMap_Grad(self): - with self.test_session(): + with self.cached_session(): param = constant_op.constant(2.0) elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") y = functional_ops.map_fn( @@ -263,142 +256,131 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMap_SimpleNotTensor(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - r = functional_ops.map_fn( - lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums) - self.assertAllEqual( - np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn( + lambda x: math_ops.multiply(math_ops.add(x, 3), 2), nums) + self.assertAllEqual( + np.array([(x + 3) * 2 for x in nums]), self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testMap_SingleInputMultiOutput(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - r = functional_ops.map_fn( - lambda x: ((x + 3) * 2, -(x + 3) * 2), - nums, - dtype=(dtypes.int64, dtypes.int64)) - self.assertEqual(2, len(r)) - self.assertEqual((6,), r[0].get_shape()) - self.assertEqual((6,), r[1].get_shape()) - received = self.evaluate(r) - self.assertAllEqual((nums + 3) * 2, received[0]) - self.assertAllEqual(-(nums + 3) * 2, received[1]) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn( + lambda x: ((x + 3) * 2, -(x + 3) * 2), + nums, + dtype=(dtypes.int64, dtypes.int64)) + self.assertEqual(2, len(r)) + self.assertEqual((6,), r[0].get_shape()) + self.assertEqual((6,), r[1].get_shape()) + received = self.evaluate(r) + self.assertAllEqual((nums + 3) * 2, received[0]) + self.assertAllEqual(-(nums + 3) * 2, received[1]) @test_util.run_in_graph_and_eager_modes def testMap_MultiOutputMismatchedDtype(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - with self.assertRaisesRegexp( - TypeError, r"two structures don't have the same nested structure"): - # lambda emits tuple, but dtype is a list - functional_ops.map_fn( - lambda x: ((x + 3) * 2, -(x + 3) * 2), - nums, - dtype=[dtypes.int64, dtypes.int64]) + nums = np.array([1, 2, 3, 4, 5, 6]) + with self.assertRaisesRegexp( + TypeError, r"two structures don't have the same nested structure"): + # lambda emits tuple, but dtype is a list + functional_ops.map_fn( + lambda x: ((x + 3) * 2, -(x + 3) * 2), + nums, + dtype=[dtypes.int64, dtypes.int64]) @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSingleOutput(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - r = functional_ops.map_fn( - lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)), - dtype=dtypes.int64) - self.assertEqual((6,), r.get_shape()) - received = self.evaluate(r) - self.assertAllEqual(nums * nums + (-nums), received) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn( + lambda x: x[0] * x[1][0] + x[1][1], (nums, (nums, -nums)), + dtype=dtypes.int64) + self.assertEqual((6,), r.get_shape()) + received = self.evaluate(r) + self.assertAllEqual(nums * nums + (-nums), received) @test_util.run_in_graph_and_eager_modes def testMap_MultiInputSameStructureOutput(self): - with self.test_session(): - nums = np.array([1, 2, 3, 4, 5, 6]) - r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])), - (nums, (2 * nums, -nums))) - r = [r[0], r[1][0], r[1][1]] - self.assertEqual((6,), r[0].get_shape()) - self.assertEqual((6,), r[1].get_shape()) - self.assertEqual((6,), r[2].get_shape()) - received = self.evaluate(r) - self.assertAllEqual(2 * nums, received[0]) - self.assertAllEqual(-nums, received[1]) - self.assertAllEqual(nums, received[2]) + nums = np.array([1, 2, 3, 4, 5, 6]) + r = functional_ops.map_fn(lambda x: (x[1][0], (x[1][1], x[0])), + (nums, (2 * nums, -nums))) + r = [r[0], r[1][0], r[1][1]] + self.assertEqual((6,), r[0].get_shape()) + self.assertEqual((6,), r[1].get_shape()) + self.assertEqual((6,), r[2].get_shape()) + received = self.evaluate(r) + self.assertAllEqual(2 * nums, received[0]) + self.assertAllEqual(-nums, received[1]) + self.assertAllEqual(nums, received[2]) @test_util.run_in_graph_and_eager_modes def testScan_Simple(self): - with self.test_session(): - elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") - v = constant_op.constant(2.0, name="v") + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") + v = constant_op.constant(2.0, name="v") - # pylint: disable=unnecessary-lambda - r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems) - self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r)) + # pylint: disable=unnecessary-lambda + r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems) + self.assertAllEqual([1., 2., 6., 24., 120., 720.], self.evaluate(r)) - r = functional_ops.scan( - lambda a, x: math_ops.multiply(a, x), elems, initializer=v) - self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) - # pylint: enable=unnecessary-lambda + r = functional_ops.scan( + lambda a, x: math_ops.multiply(a, x), elems, initializer=v) + self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r)) + # pylint: enable=unnecessary-lambda @test_util.run_in_graph_and_eager_modes def testScan_Reverse(self): - with self.test_session(): - elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") - v = constant_op.constant(2.0, name="v") - - # pylint: disable=unnecessary-lambda - r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems, - reverse=True) - self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r)) - r = functional_ops.scan( - lambda a, x: math_ops.multiply(a, x), elems, initializer=v, - reverse=True) - self.assertAllEqual([1440., 1440., 720., 240., 60., 12.], - self.evaluate(r)) - # pylint: enable=unnecessary-lambda + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") + v = constant_op.constant(2.0, name="v") + + # pylint: disable=unnecessary-lambda + r = functional_ops.scan(lambda a, x: math_ops.multiply(a, x), elems, + reverse=True) + self.assertAllEqual([720., 720., 360., 120., 30., 6.], self.evaluate(r)) + r = functional_ops.scan( + lambda a, x: math_ops.multiply(a, x), elems, initializer=v, + reverse=True) + self.assertAllEqual([1440., 1440., 720., 240., 60., 12.], + self.evaluate(r)) + # pylint: enable=unnecessary-lambda @test_util.run_in_graph_and_eager_modes def testScan_SingleInputMultiOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = (np.array(1.0), np.array(-1.0)) - r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, - initializer) - r_value = self.evaluate(r) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = (np.array(1.0), np.array(-1.0)) + r = functional_ops.scan(lambda a, x: (a[0] * x, -a[1] * x), elems, + initializer) + r_value = self.evaluate(r) - self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) - self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) + self.assertAllEqual([1.0, 2.0, 6.0, 24.0, 120.0, 720.0], r_value[0]) + self.assertAllEqual([1.0, -2.0, 6.0, -24.0, 120.0, -720.0], r_value[1]) @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSingleOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - # Multiply a * 1 each time - r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]), - (elems + 1, -elems), initializer) - self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + # Multiply a * 1 each time + r = functional_ops.scan(lambda a, x: a * (x[0] + x[1]), + (elems + 1, -elems), initializer) + self.assertAllEqual([1.0, 1.0, 1.0, 1.0, 1.0, 1.0], self.evaluate(r)) @test_util.run_in_graph_and_eager_modes def testScan_MultiInputSameTypeOutput(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]), - (elems, -elems)) - r_value = self.evaluate(r) - self.assertAllEqual(np.cumsum(elems), r_value[0]) - self.assertAllEqual(np.cumsum(-elems), r_value[1]) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + r = functional_ops.scan(lambda a, x: (a[0] + x[0], a[1] + x[1]), + (elems, -elems)) + r_value = self.evaluate(r) + self.assertAllEqual(np.cumsum(elems), r_value[0]) + self.assertAllEqual(np.cumsum(-elems), r_value[1]) @test_util.run_in_graph_and_eager_modes def testScan_MultiOutputMismatchedInitializer(self): - with self.test_session(): - elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - initializer = np.array(1.0) - # Multiply a * 1 each time - with self.assertRaisesRegexp( - ValueError, "two structures don't have the same nested structure"): - functional_ops.scan(lambda a, x: (a, -a), elems, initializer) + elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + initializer = np.array(1.0) + # Multiply a * 1 each time + with self.assertRaisesRegexp( + ValueError, "two structures don't have the same nested structure"): + functional_ops.scan(lambda a, x: (a, -a), elems, initializer) def testScan_Scoped(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root") as varscope: elems = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") @@ -420,30 +402,29 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testScanFoldl_Nested(self): - with self.test_session(): - elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") - inner_elems = constant_op.constant([0.5, 0.5], name="data") - - def r_inner(a, x): - return functional_ops.foldl( - lambda b, y: b * y * x, inner_elems, initializer=a) - - r = functional_ops.scan(r_inner, elems) - - # t == 0 (returns 1) - # t == 1, a == 1, x == 2 (returns 1) - # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1 - # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1 - # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25) - # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5 - # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5 - # t == 3, a == 2.25, x == 4 (returns 9) - # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5 - # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9 - self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r)) + elems = constant_op.constant([1.0, 2.0, 3.0, 4.0], name="data") + inner_elems = constant_op.constant([0.5, 0.5], name="data") + + def r_inner(a, x): + return functional_ops.foldl( + lambda b, y: b * y * x, inner_elems, initializer=a) + + r = functional_ops.scan(r_inner, elems) + + # t == 0 (returns 1) + # t == 1, a == 1, x == 2 (returns 1) + # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1 + # t_1 == 1, b == 1, y == 0.5, returns b * y * x = 1 + # t == 2, a == 1, x == 3 (returns 1.5*1.5 == 2.25) + # t_0 == 0, b == a == 1, y == 0.5, returns b * y * x = 1.5 + # t_1 == 1, b == 1.5, y == 0.5, returns b * y * x = 1.5*1.5 + # t == 3, a == 2.25, x == 4 (returns 9) + # t_0 == 0, b == a == 2.25, y == 0.5, returns b * y * x = 4.5 + # t_1 == 1, b == 4.5, y == 0.5, returns b * y * x = 9 + self.assertAllClose([1., 1., 2.25, 9.], self.evaluate(r)) def testScan_Control(self): - with self.test_session() as sess: + with self.cached_session() as sess: s = array_ops.placeholder(dtypes.float32, shape=[None]) b = array_ops.placeholder(dtypes.bool) @@ -454,7 +435,7 @@ class FunctionalOpsTest(test.TestCase): b: True})) def testScan_Grad(self): - with self.test_session(): + with self.cached_session(): elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data") v = constant_op.constant(2.0, name="v") @@ -479,22 +460,20 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testFoldShape(self): - with self.test_session(): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) + x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - def fn(_, current_input): - return current_input + def fn(_, current_input): + return current_input - initializer = constant_op.constant([0, 0, 0]) - y = functional_ops.foldl(fn, x, initializer=initializer) - self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) + initializer = constant_op.constant([0, 0, 0]) + y = functional_ops.foldl(fn, x, initializer=initializer) + self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) @test_util.run_in_graph_and_eager_modes def testMapShape(self): - with self.test_session(): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - y = functional_ops.map_fn(lambda e: e, x) - self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) + x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) + y = functional_ops.map_fn(lambda e: e, x) + self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) def testMapUnknownShape(self): x = array_ops.placeholder(dtypes.float32) @@ -503,15 +482,14 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testMapEmptyScalar(self): - with self.test_session(): - map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) - self.assertAllEqual([0], map_return.get_shape().dims) - self.assertAllEqual([0], self.evaluate(map_return).shape) + map_return = functional_ops.map_fn(lambda x: 1, constant_op.constant([])) + self.assertAllEqual([0], map_return.get_shape().dims) + self.assertAllEqual([0], self.evaluate(map_return).shape) # TODO(akshayka): this test fails in eager: the iterable is of length 0 so # so the body of the while loop never executes def testMapEmptyTensor(self): - with self.test_session(): + with self.cached_session(): map_return = functional_ops.map_fn(lambda x: array_ops.zeros([3, 2]), constant_op.constant([])) self.assertAllEqual([0, 3, 2], map_return.get_shape().dims) @@ -519,20 +497,19 @@ class FunctionalOpsTest(test.TestCase): @test_util.run_in_graph_and_eager_modes def testScanShape(self): - with self.test_session(): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) + x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) - def fn(_, current_input): - return current_input + def fn(_, current_input): + return current_input - initializer = constant_op.constant([0, 0, 0]) - y = functional_ops.scan(fn, x, initializer=initializer) - self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) + initializer = constant_op.constant([0, 0, 0]) + y = functional_ops.scan(fn, x, initializer=initializer) + self.assertAllEqual(y.get_shape(), self.evaluate(y).shape) # TODO(akshayka): this test fails in eager: the iterable is of length 0 so # so the body of the while loop never executes def testScanEmptyTensor(self): - with self.test_session(): + with self.cached_session(): x = functional_ops.scan( lambda x, _: x, math_ops.range(0), initializer=array_ops.ones([2, 4])) self.assertAllEqual([0, 2, 4], x.get_shape()) @@ -549,7 +526,7 @@ class FunctionalOpsTest(test.TestCase): self.assertIs(None, y.get_shape().dims) def testScanVaryingShape(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.placeholder(dtype=dtypes.float32, shape=[None, 2]) x_t = array_ops.transpose(x) # scan over dimension 0 (with shape None) @@ -628,7 +605,7 @@ class FunctionalOpsTest(test.TestCase): remote_op = functional_ops.remote_call( args=[a, b], Tout=[dtypes.int32], f=_remote_fn, target="/cpu:0") - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) mul = sess.run(remote_op) self.assertEqual(mul, [6]) @@ -652,7 +629,7 @@ class FunctionalOpsTest(test.TestCase): f=_remote_fn, target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) mul = sess.run(remote_op) self.assertEqual(mul, 9.0) @@ -676,7 +653,7 @@ class FunctionalOpsTest(test.TestCase): f=_remote_fn, target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(variables.global_variables_initializer()) mul = sess.run(remote_op) self.assertEqual(mul, 9.0) @@ -695,7 +672,7 @@ class FunctionalOpsTest(test.TestCase): remote_op = functional_ops.remote_call( args=[a], Tout=[dtypes.string], f=_remote_fn, target="/cpu:0") - with self.test_session() as sess: + with self.cached_session() as sess: ret = sess.run(remote_op) self.assertAllEqual(ret, [b"a"]) diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index ff941b64fa..0f5607712b 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -170,9 +170,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_pop_back( l_cpu, element_dtype=dtypes.float32)[1]), 2.0) - @test_util.run_in_graph_and_eager_modes def testGraphStack(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): tl = list_ops.empty_tensor_list( element_shape=constant_op.constant([1], dtype=dtypes.int32), element_dtype=dtypes.int32) @@ -182,9 +181,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)), [[1]]) - @test_util.run_in_graph_and_eager_modes def testGraphStackInLoop(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): t1 = list_ops.empty_tensor_list( element_shape=constant_op.constant([], dtype=dtypes.int32), element_dtype=dtypes.int32) @@ -200,9 +198,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32) self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3]) - @test_util.run_in_graph_and_eager_modes def testGraphStackSwitchDtype(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): list_ = list_ops.empty_tensor_list( element_shape=constant_op.constant([], dtype=dtypes.int32), element_dtype=dtypes.int32) @@ -222,9 +219,8 @@ class ListOpsTest(test_util.TensorFlowTestCase): np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32) self.assertAllEqual(self.evaluate(s1), np_s1) - @test_util.run_in_graph_and_eager_modes def testGraphStackInLoopSwitchDtype(self): - with context.graph_mode(), self.test_session(): + with self.cached_session(): t1 = list_ops.empty_tensor_list( element_shape=constant_op.constant([], dtype=dtypes.int32), element_dtype=dtypes.int32) diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 50154a45a8..79fcbaad43 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -61,7 +61,7 @@ class PyFuncTest(test.TestCase): for dtype in [dtypes.float16, dtypes.float32, dtypes.float64, dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16, dtypes.int32, dtypes.int64]: - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1, dtype=dtype) y = constant_op.constant(2, dtype=dtype) z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype)) @@ -71,7 +71,7 @@ class PyFuncTest(test.TestCase): def sub_func(x, y): return x - y for dtype in [dtypes.complex64, dtypes.complex128]: - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1 + 1j, dtype=dtype) y = constant_op.constant(2 - 2j, dtype=dtype) z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype)) @@ -81,21 +81,21 @@ class PyFuncTest(test.TestCase): def and_func(x, y): return x and y dtype = dtypes.bool - with self.test_session(): + with self.cached_session(): x = constant_op.constant(True, dtype=dtype) y = constant_op.constant(False, dtype=dtype) z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype)) self.assertEqual(z, False) def testSingleType(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1.0, dtypes.float32) y = constant_op.constant(2.0, dtypes.float32) z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32)) self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32)) def testScalar(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1.0, dtypes.float32) y = constant_op.constant(2.0, dtypes.float32) z = self.evaluate( @@ -103,7 +103,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32)) def testArray(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant([1.0, 2.0], dtypes.float64) y = constant_op.constant([2.0, 3.0], dtypes.float64) z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64])) @@ -111,14 +111,14 @@ class PyFuncTest(test.TestCase): np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64)) def testComplexType(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant(1 + 2j, dtypes.complex64) y = constant_op.constant(3 + 4j, dtypes.complex64) z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64)) self.assertAllClose(z, np_func(1 + 2j, 3 + 4j)) def testRFFT(self): - with self.test_session(): + with self.cached_session(): x = constant_op.constant([1., 2., 3., 4.], dtypes.float32) def rfft(x): @@ -128,7 +128,7 @@ class PyFuncTest(test.TestCase): self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.])) def testPythonLiteral(self): - with self.test_session(): + with self.cached_session(): def literal(x): return 1.0 if float(x) == 0.0 else 0.0 @@ -138,7 +138,7 @@ class PyFuncTest(test.TestCase): self.assertAllClose(y, 1.0) def testList(self): - with self.test_session(): + with self.cached_session(): def list_func(x): return [x, x + 1] @@ -150,7 +150,7 @@ class PyFuncTest(test.TestCase): def testTuple(self): # returns a tuple - with self.test_session(): + with self.cached_session(): def tuple_func(x): return x, x + 1 @@ -161,7 +161,7 @@ class PyFuncTest(test.TestCase): self.assertAllClose(y, [0.0, 1.0]) # returns a tuple, Tout and inp a tuple - with self.test_session(): + with self.cached_session(): x = constant_op.constant(0.0, dtypes.float64) y = self.evaluate( script_ops.py_func(tuple_func, (x,), @@ -176,7 +176,7 @@ class PyFuncTest(test.TestCase): def read_and_return_strings(x, y): return x + y - with self.test_session(): + with self.cached_session(): x = constant_op.constant([b"hello", b"hi"], dtypes.string) y = self.evaluate( script_ops.py_func(read_fixed_length_numpy_strings, [], @@ -193,7 +193,7 @@ class PyFuncTest(test.TestCase): def read_and_return_strings(x, y): return x + y - with self.test_session(): + with self.cached_session(): x = constant_op.constant(["hello", "hi"], dtypes.string) y = self.evaluate( script_ops.py_func(read_fixed_length_numpy_strings, [], @@ -210,7 +210,7 @@ class PyFuncTest(test.TestCase): def read_and_return_strings(x, y): return x + y - with self.test_session(): + with self.cached_session(): x = constant_op.constant(["hello", "hi"], dtypes.string) y, = script_ops.py_func(read_object_array, [], [dtypes.string]) @@ -219,19 +219,19 @@ class PyFuncTest(test.TestCase): def testStringPadding(self): correct = [b"this", b"is", b"a", b"test"] - with self.test_session(): + with self.cached_session(): s, = script_ops.py_func(lambda: [correct], [], [dtypes.string]) self.assertAllEqual(s.eval(), correct) def testStringPaddingAreConvertedToBytes(self): inp = ["this", "is", "a", "test"] correct = [b"this", b"is", b"a", b"test"] - with self.test_session(): + with self.cached_session(): s, = script_ops.py_func(lambda: [inp], [], [dtypes.string]) self.assertAllEqual(s.eval(), correct) def testLarge(self): - with self.test_session() as sess: + with self.cached_session() as sess: x = array_ops.zeros([1000000], dtype=np.float32) y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32]) z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32]) @@ -239,12 +239,12 @@ class PyFuncTest(test.TestCase): sess.run([y[0].op, z[0].op]) def testNoInput(self): - with self.test_session(): + with self.cached_session(): x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64)) self.assertAllClose(x, 42.0) def testAlias(self): - with self.test_session(): + with self.cached_session(): np_array = np.array([1.0, 2.0], dtype=np.float32) tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32]) value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32) @@ -252,7 +252,7 @@ class PyFuncTest(test.TestCase): self.assertAllEqual(np_array, [1.0, 2.0]) def testReturnUnicodeString(self): - with self.test_session(): + with self.cached_session(): correct = u"你好 世界" def unicode_string(): @@ -262,7 +262,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(z.eval(), correct.encode("utf8")) def testBadNumpyReturnType(self): - with self.test_session(): + with self.cached_session(): def bad(): # Structured numpy arrays aren't supported. @@ -275,7 +275,7 @@ class PyFuncTest(test.TestCase): y.eval() def testBadReturnType(self): - with self.test_session(): + with self.cached_session(): def bad(): # Non-string python objects aren't supported. @@ -288,7 +288,7 @@ class PyFuncTest(test.TestCase): z.eval() def testReturnInput(self): - with self.test_session(): + with self.cached_session(): def ident(x): return x[0] @@ -303,7 +303,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]})) def testStateful(self): - # Not using self.test_session(), which disables optimization. + # Not using self.cached_session(), which disables optimization. with session_lib.Session() as sess: producer = iter(range(3)) x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64]) @@ -312,7 +312,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(sess.run(x), 2) def testStateless(self): - # Not using self.test_session(), which disables optimization. + # Not using self.cached_session(), which disables optimization. with session_lib.Session() as sess: producer = iter(range(3)) x, = script_ops.py_func( @@ -331,7 +331,7 @@ class PyFuncTest(test.TestCase): self.assertEqual(None, ops.get_gradient_function(y.op)) def testCOrder(self): - with self.test_session(): + with self.cached_session(): val = [[1, 2], [3, 4]] x, = script_ops.py_func(lambda: np.array(val, order="F"), [], [dtypes.int64]) @@ -339,7 +339,7 @@ class PyFuncTest(test.TestCase): def testParallel(self): # Tests that tf.py_func's can run in parallel if they release the GIL. - with self.test_session() as session: + with self.cached_session() as session: q = queue.Queue(1) def blocking_put(): @@ -375,7 +375,7 @@ class PyFuncTest(test.TestCase): def value(self): return self._value - with self.test_session(): + with self.cached_session(): s = State() op = s.increment(constant_op.constant(2, dtypes.int64)) ret = self.evaluate(op) @@ -389,7 +389,7 @@ class PyFuncTest(test.TestCase): f = script_ops.py_func( do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False) - with self.test_session() as sess: + with self.cached_session() as sess: self.assertEqual(sess.run(f), []) def _testExceptionHandling(self, py_exp, tf_exp, eager=False): @@ -417,21 +417,22 @@ class PyFuncTest(test.TestCase): else: f = script_ops.py_func(raise_exception, [], []) - with self.test_session(): - with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): - self.evaluate(f) + with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): + self.evaluate(f) def testExceptionHandling(self): - self._testExceptionHandling(ValueError, errors.InvalidArgumentError) - self._testExceptionHandling(TypeError, errors.InvalidArgumentError) - self._testExceptionHandling(StopIteration, errors.OutOfRangeError) - self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError) - self._testExceptionHandling(NotImplementedError, errors.UnimplementedError) + with self.cached_session(): + self._testExceptionHandling(ValueError, errors.InvalidArgumentError) + self._testExceptionHandling(TypeError, errors.InvalidArgumentError) + self._testExceptionHandling(StopIteration, errors.OutOfRangeError) + self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError) + self._testExceptionHandling(NotImplementedError, + errors.UnimplementedError) - class WeirdError(Exception): - pass + class WeirdError(Exception): + pass - self._testExceptionHandling(WeirdError, errors.UnknownError) + self._testExceptionHandling(WeirdError, errors.UnknownError) # ----- Tests shared by py_func and eager_py_func ----- def testCleanup(self): @@ -452,7 +453,7 @@ class PyFuncTest(test.TestCase): # (see #18292) _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) - + # Call garbage collector to enforce deletion. make_graphs() ops.reset_default_graph() @@ -610,7 +611,7 @@ class PyFuncTest(test.TestCase): func=log_huber, inp=[x, m], Tout=dtypes.float32) dy_dx = gradients_impl.gradients(y, x)[0] - with self.test_session() as sess: + with self.cached_session() as sess: # Takes the first branch of log_huber. y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0}) self.assertEqual(y, 1.0) diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index d0ed08933d..f90545f84c 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -54,7 +54,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(0, len(gc.garbage)) def testHandleDtypeShapeMatch(self): - with self.test_session(): + with self.cached_session(): handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) with self.assertRaises(ValueError): resource_variable_ops.assign_variable_op( @@ -123,7 +123,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertFalse(np.allclose(variable.numpy(), copied_variable.numpy())) def testGraphDeepCopy(self): - with self.test_session(): + with self.cached_session(): init_value = np.ones((4, 4, 4)) variable = resource_variable_ops.ResourceVariable(init_value, name="init") @@ -145,13 +145,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): # variable graph. def testFetchHandle(self): - with self.test_session(): + with self.cached_session(): handle = resource_variable_ops.var_handle_op( dtype=dtypes.int32, shape=[1], name="foo") self.assertGreater(len(handle.eval()), 0) def testCachedValueReadBeforeWrite(self): - with self.test_session() as sess: + with self.cached_session() as sess: v = resource_variable_ops.ResourceVariable(0.0, caching_device="cpu:0") sess.run(v.initializer) value, _ = sess.run([v, v.assign_add(1.0)]) @@ -492,7 +492,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): # TODO(alive): how should this work in Eager mode? def testInitFn(self): - with self.test_session(): + with self.cached_session(): v = resource_variable_ops.ResourceVariable( initial_value=lambda: 1, dtype=dtypes.float32) self.assertEqual(v.handle.op.colocation_groups(), @@ -569,11 +569,11 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(2.0, self.evaluate(v.value())) def testVariableDefInitializedInstances(self): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: v_def = resource_variable_ops.ResourceVariable( initial_value=constant_op.constant(3.0)).to_proto() - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: # v describes a VariableDef-based variable without an initial value. v = resource_variable_ops.ResourceVariable(variable_def=v_def) self.assertEqual(3.0, sess.run(v.initialized_value())) @@ -584,7 +584,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): self.assertEqual(1.0, v.initialized_value().eval()) v_def.ClearField("initial_value_name") - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: # Restoring a legacy VariableDef proto that does not have # initial_value_name set should still work. v = resource_variable_ops.ResourceVariable(variable_def=v_def) @@ -615,17 +615,16 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes def testSparseRead(self): - with self.test_session(): - init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) - v = resource_variable_ops.ResourceVariable( - constant_op.constant(init_value, dtype=dtypes.int32), name="var3") - self.evaluate(variables.global_variables_initializer()) + init_value = np.reshape(np.arange(np.power(4, 3)), (4, 4, 4)) + v = resource_variable_ops.ResourceVariable( + constant_op.constant(init_value, dtype=dtypes.int32), name="var3") + self.evaluate(variables.global_variables_initializer()) - value = self.evaluate(v.sparse_read([0, 3, 1, 2])) - self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value) + value = self.evaluate(v.sparse_read([0, 3, 1, 2])) + self.assertAllEqual(init_value[[0, 3, 1, 2], ...], value) def testToFromProto(self): - with self.test_session(): + with self.cached_session(): v = resource_variable_ops.ResourceVariable(1.0) variables.global_variables_initializer().run() @@ -686,7 +685,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): handle, ignore_lookup_error=True)) def testAssignDifferentShapes(self): - with self.test_session() as sess, variable_scope.variable_scope( + with self.cached_session() as sess, variable_scope.variable_scope( "foo", use_resource=True): var = variable_scope.get_variable("x", shape=[1, 1], dtype=dtypes.float32) placeholder = array_ops.placeholder(dtypes.float32) @@ -728,7 +727,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): _ = w.value().op.get_attr("_class") def testSharedName(self): - with self.test_session(): + with self.cached_session(): v = resource_variable_ops.ResourceVariable(300.0, name="var4") variables.global_variables_initializer().run() @@ -746,7 +745,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.read_variable_op(x, v.dtype.base_dtype).eval() def testSharedNameWithNamescope(self): - with self.test_session(): + with self.cached_session(): with ops.name_scope("foo"): v = resource_variable_ops.ResourceVariable(300.0, name="var6") self.assertEqual("foo/var6", v._shared_name) # pylint: disable=protected-access @@ -774,7 +773,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): str(v.sparse_read(array_ops.placeholder(dtypes.int32)).shape)) def testSetInitialValue(self): - with self.test_session(): + with self.cached_session(): # Initialize variable with a value different from the initial value passed # in the constructor. v = resource_variable_ops.ResourceVariable(2.0) diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 562d11f0b0..a28cdc3b26 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -197,7 +197,7 @@ class RNNTest(test.TestCase): else: inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - with self.test_session() as sess: + with self.cached_session(use_gpu=True) as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) if not in_eager_mode: @@ -217,7 +217,7 @@ class RNNTest(test.TestCase): else: inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - with self.test_session() as sess: + with self.cached_session(use_gpu=True) as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) if not in_eager_mode: @@ -246,7 +246,7 @@ class RNNTest(test.TestCase): else: inputs = array_ops.placeholder(dtypes.float32, shape=(1, 4, 1)) - with self.test_session() as sess: + with self.cached_session(use_gpu=True) as sess: outputs, state = rnn.dynamic_rnn( cell, inputs, dtype=dtypes.float32, sequence_length=[4]) state = (state[0], state[1].stack()) @@ -321,7 +321,7 @@ class RNNTest(test.TestCase): self._assert_cell_builds(contrib_rnn.IndyLSTMCell, f64, 5, 7, 3) def testRNNWithKerasSimpleRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -354,7 +354,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(state), batch) def testRNNWithKerasGRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -387,7 +387,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(state), batch) def testRNNWithKerasLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -424,7 +424,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(state[1]), batch) def testRNNWithStackKerasCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -465,7 +465,7 @@ class RNNTest(test.TestCase): self.assertEqual(len(s), batch) def testStaticRNNWithKerasSimpleRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: input_shape = 10 output_shape = 5 timestep = 4 @@ -567,7 +567,7 @@ class RNNTest(test.TestCase): rnn_cell_impl.GRUCell( 32, kernel_initializer="ones", dtype=dtypes.float32) ]: - with self.test_session(): + with self.cached_session(): x = keras.Input((None, 5)) layer = keras.layers.RNN(cell) y = layer(x) -- GitLab From d046dd6501af0ca7d90a6ce7611dfe23a99aa781 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Tue, 4 Sep 2018 18:45:42 -0700 Subject: [PATCH 097/540] Move iterator.get_next() to be called inside fit from inside of standardize function. PiperOrigin-RevId: 211564198 --- tensorflow/python/keras/engine/training.py | 42 +++------------- .../keras/engine/training_distributed.py | 49 +++++++++++++------ 2 files changed, 42 insertions(+), 49 deletions(-) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 85d25411b4..ef6a04b00f 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -790,10 +790,7 @@ class Model(Network): Fraction of the training data to be used as validation data. Returns: - A tuple of 3 lists: input arrays, target arrays, sample-weight arrays. - If the model's input and targets are symbolic, these lists are empty - (since the model takes no user-provided data, instead the data comes - from the symbolic inputs/targets). + Iterator for reading the dataset `x`. Raises: ValueError: In case of invalid user-provided data. @@ -828,30 +825,7 @@ class Model(Network): training_utils.validate_iterator_input(x, y, sample_weight, validation_split) - # x an y may be PerDevice objects with an input and output tensor - # corresponding to each device. For example, x could be - # PerDevice:{device: get_next tensor,...}. - next_element = iterator.get_next() - - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError('Please provide model inputs as a list or tuple of 2 ' - 'elements: input and target pair. ' - 'Received %s' % next_element) - x, y = next_element - # Validate that all the elements in x and y are of the same type and shape. - # We can then pass the first element of x and y to `_standardize_weights` - # below and be confident of the output. We need to reopen the scope since - # we unwrap values when we validate x and y. - with self._distribution_strategy.scope(): - x_values, y_values = distributed_training_utils.\ - validate_distributed_dataset_inputs(self._distribution_strategy, x, y) - - _, _, sample_weights = self._standardize_weights(x_values, - y_values, - sample_weight, - class_weight, - batch_size) - return x, y, sample_weights + return iterator def _standardize_user_data(self, x, @@ -916,7 +890,7 @@ class Model(Network): RuntimeError: If the model was never compiled. """ if self._distribution_strategy: - return self._distribution_standardize_user_data( + iterator = self._distribution_standardize_user_data( x, y, sample_weight=sample_weight, @@ -926,6 +900,7 @@ class Model(Network): steps_name=steps_name, steps=steps, validation_split=validation_split) + return iterator, None, None if isinstance(x, dataset_ops.Dataset): if context.executing_eagerly(): @@ -982,6 +957,7 @@ class Model(Network): def _standardize_weights(self, x, y, sample_weight=None, class_weight=None, batch_size=None,): + # TODO(sourabhbajaj): Split input validation from weight standardization. if sample_weight is not None and class_weight is not None: logging.warning( 'Received both a `sample_weight` and `class_weight` argument. ' @@ -1566,12 +1542,11 @@ class Model(Network): validation_steps=validation_steps) elif self._distribution_strategy: return training_distributed.fit_loop( - self, x, y, + self, x, epochs=epochs, verbose=verbose, callbacks=callbacks, - val_inputs=val_x, - val_targets=val_y, + val_iterator=val_x, initial_epoch=initial_epoch, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps) @@ -1677,8 +1652,7 @@ class Model(Network): elif self._distribution_strategy: return training_distributed.test_loop( self, - inputs=x, - targets=y, + iterator=x, verbose=verbose, steps=steps) else: diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index 85f1d6299f..b7f43dea56 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -30,13 +30,11 @@ from tensorflow.python.platform import tf_logging as logging def fit_loop( model, - inputs, - targets, + iterator, epochs=100, verbose=1, callbacks=None, - val_inputs=None, - val_targets=None, + val_iterator=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None): @@ -44,13 +42,11 @@ def fit_loop( Arguments: model: Keras Model instance. - inputs: List of input arrays. - targets: List of target arrays. + iterator: Iterator for input data. epochs: Number of times to iterate over the data verbose: Verbosity mode, 0, 1 or 2 callbacks: List of callbacks to be called during training - val_inputs: List of input arrays. - val_targets: List of target arrays. + val_iterator: Iterator for validation data. initial_epoch: Epoch at which to start training (useful for resuming a previous training run) steps_per_epoch: Total number of steps (batches of samples) @@ -74,6 +70,7 @@ def fit_loop( model.train_function.updates_op, model.train_function.session_kwargs) + inputs, targets = _get_input_from_iterator(iterator, model) with current_strategy.scope(): # Create train ops on each of the devices when we call # `_per_device_train_function`. @@ -169,8 +166,7 @@ def fit_loop( if do_validation: val_outs = test_loop( model, - val_inputs, - val_targets, + val_iterator, steps=validation_steps, verbose=0) if not isinstance(val_outs, list): @@ -192,13 +188,12 @@ def fit_loop( return model.history -def test_loop(model, inputs, targets, verbose=0, steps=None): +def test_loop(model, iterator, verbose=0, steps=None): """evaluate method to validate a model that uses DistributionStrategy. Arguments: model: Keras Model instance. - inputs: List of input arrays. - targets: List of target arrays. + iterator: Iterator for input data. verbose: verbosity mode. steps: Total number of steps (batches of samples) before declaring predictions finished. @@ -218,6 +213,7 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): model.test_function.updates_op, model.test_function.session_kwargs) + inputs, targets = _get_input_from_iterator(iterator, model) with current_strategy.scope(): (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( @@ -284,12 +280,12 @@ def test_loop(model, inputs, targets, verbose=0, steps=None): return outs -def predict_loop(model, inputs, verbose=0, steps=None): +def predict_loop(model, iterator, verbose=0, steps=None): """Abstract method to loop over some data in batches. Arguments: model: Keras Model instance. - inputs: list of tensors to be fed to `f`. + iterator: Iterator for input data. verbose: verbosity mode. steps: Total number of steps (batches of samples) before declaring `_predict_loop` finished. @@ -308,6 +304,7 @@ def predict_loop(model, inputs, verbose=0, steps=None): model.predict_function.updates_op, model.predict_function.session_kwargs) + inputs, _ = _get_input_from_iterator(iterator, model) with current_strategy.scope(): (grouped_inputs, grouped_outputs, grouped_updates, grouped_session_args) = current_strategy.call_for_each_tower( @@ -419,3 +416,25 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs): merged_output.append(m) current_index += num_devices return merged_output + + +def _get_input_from_iterator(iterator, model): + """Get elements from the iterator and verify the input shape and type.""" + next_element = iterator.get_next() + # TODO(anjalisridhar): Support predict input correctly as it will not contain + # targets, only inputs. + if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: + raise ValueError('Please provide model inputs as a list or tuple of 2 ' + 'elements: input and target pair. ' + 'Received %s' % next_element) + + x, y = next_element + # Validate that all the elements in x and y are of the same type and shape. + # We can then pass the first element of x and y to `_standardize_weights` + # below and be confident of the output. + x_values, y_values = distributed_training_utils.\ + validate_distributed_dataset_inputs(model._distribution_strategy, x, y) + # TODO(sourabhbajaj): Add support for sample weights in distribution + # strategy. + model._standardize_weights(x_values, y_values) + return x, y -- GitLab From ecb6bc19e0cdbd2f2e98de909b4f3b8ca9fd7ab1 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Tue, 4 Sep 2018 20:09:05 -0700 Subject: [PATCH 098/540] Clone the model in fit instead of compile for distribution strategy in keras. PiperOrigin-RevId: 211570665 --- tensorflow/python/keras/engine/training.py | 45 +++++++------------ .../keras/engine/training_distributed.py | 22 ++++++++- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index ef6a04b00f..e07220d15a 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -405,20 +405,7 @@ class Model(Network): # Set DistributionStrategy specific parameters. self._distribution_strategy = distribute if self._distribution_strategy is not None: - self._grouped_model = self._compile_distributed_model( - self._distribution_strategy) - with self._distribution_strategy.scope(): - first_replicated_model = self._distribution_strategy.unwrap( - self._grouped_model)[0] - # If the specified metrics in `compile` are stateful, raise an error - # since we currently don't support stateful metrics. - if first_replicated_model.stateful_metric_names: - raise NotImplementedError('Stateful metrics are not supported with ' - 'DistributionStrategy.') - - # We initialize the callback model with the first replicated model. - self._replicated_model = DistributedCallbackModel(first_replicated_model) - self._replicated_model.set_original_model(self) + self._grouped_model = None if not self.built: # Model is not compilable because it does not know its number of inputs # and outputs, nor their shapes and names. We will compile after the first @@ -636,6 +623,12 @@ class Model(Network): skip_target_indices=skip_target_indices, sample_weights=self.sample_weights) + # If using distribution strategy and stateful_metrics, raise an error + # since we currently don't support stateful metrics. + if self._distribution_strategy is not None and self.stateful_metric_names: + raise NotImplementedError('Stateful metrics are not supported with ' + 'DistributionStrategy.') + # Prepare gradient updates and state updates. self.total_loss = total_loss @@ -652,19 +645,6 @@ class Model(Network): trainable_weights = self.trainable_weights self._collected_trainable_weights = trainable_weights - def _compile_distributed_model(self, distribution_strategy): - # TODO(anjalisridhar): Can we move the clone_and_build_model to outside the - # model? - def _clone_model_per_tower(model): - new_model = training_distributed.clone_and_build_model(model) - return new_model - - with distribution_strategy.scope(): - # Create a copy of this model on each of the devices. - grouped_models = distribution_strategy.call_for_each_tower( - _clone_model_per_tower, self) - return grouped_models - def _check_trainable_weights_consistency(self): """Check trainable weights count consistency. @@ -2162,6 +2142,13 @@ class Model(Network): return self.callback_model return self + def _make_callback_model(self): + first_replicated_model = self._distribution_strategy.unwrap( + self._grouped_model)[0] + # We initialize the callback model with the first replicated model. + self._replicated_model = DistributedCallbackModel(first_replicated_model) + self._replicated_model.set_original_model(self) + class DistributedCallbackModel(Model): """Model that is used for callbacks with DistributionStrategy.""" @@ -2199,6 +2186,6 @@ class DistributedCallbackModel(Model): # Whitelisted atttributes of the model that can be accessed by the user # during a callback. if item not in ['_setattr_tracking']: - logging.warning('You are accessing attribute ' + item + 'of the' - 'DistributedCallbackModel that may not have been set' + logging.warning('You are accessing attribute ' + item + 'of the ' + 'DistributedCallbackModel that may not have been set ' 'correctly.') diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index b7f43dea56..a7bb1f8177 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -63,6 +63,10 @@ def fit_loop( ValueError: in case of invalid arguments. """ current_strategy = model._distribution_strategy + + clone_model_on_towers( + model, current_strategy, make_callback_model=True) + def _per_device_train_function(model): model._make_train_function() return (model.train_function.inputs, @@ -206,6 +210,9 @@ def test_loop(model, iterator, verbose=0, steps=None): the display labels for the scalar outputs. """ current_strategy = model._distribution_strategy + + clone_model_on_towers(model, current_strategy) + def _per_device_test_function(model): model._make_test_function() return (model.test_function.inputs, @@ -297,6 +304,9 @@ def predict_loop(model, iterator, verbose=0, steps=None): (if the model has multiple outputs). """ current_strategy = model._distribution_strategy + + clone_model_on_towers(model, current_strategy) + def _per_device_predict_function(model): model._make_predict_function() return (model.predict_function.inputs, @@ -363,7 +373,7 @@ def predict_loop(model, iterator, verbose=0, steps=None): ] -def clone_and_build_model(model): +def _clone_and_build_model(model): """Clone and build the given keras_model.""" # We need to set the import here since we run into a circular dependency # error. @@ -387,6 +397,16 @@ def clone_and_build_model(model): return cloned_model +def clone_model_on_towers(model, strategy, make_callback_model=False): + """Create a cloned model on each tower, unless already created.""" + if not model._grouped_model: + with strategy.scope(): + model._grouped_model = strategy.call_for_each_tower( + _clone_and_build_model, model) + if make_callback_model: + model._make_callback_model() + + def _aggregate_metrics_across_towers(num_devices, out_labels, outs): """Aggregate metrics values across all towers. -- GitLab From e9332539bea372f6dbe6ef185f9d8b1f3b6e1fe2 Mon Sep 17 00:00:00 2001 From: Alan Chiao Date: Tue, 4 Sep 2018 20:22:55 -0700 Subject: [PATCH 099/540] Relu1 custom op. This is implemented as custom op instead of builtin op because Relu1 is not supported in Tensorflow and not commonly used. PiperOrigin-RevId: 211571619 --- tensorflow/contrib/lite/kernels/BUILD | 18 +++++ tensorflow/contrib/lite/kernels/register.cc | 2 + tensorflow/contrib/lite/kernels/relu1.cc | 59 ++++++++++++++ tensorflow/contrib/lite/kernels/relu1_test.cc | 79 +++++++++++++++++++ 4 files changed, 158 insertions(+) create mode 100644 tensorflow/contrib/lite/kernels/relu1.cc create mode 100644 tensorflow/contrib/lite/kernels/relu1_test.cc diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index ab989c5425..b7c5cbf207 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -192,6 +192,7 @@ cc_library( "pooling.cc", "pow.cc", "reduce.cc", + "relu1.cc", "reshape.cc", "resize_bilinear.cc", "select.cc", @@ -304,6 +305,23 @@ tf_cc_test( ], ) +tf_cc_test( + name = "relu1_test", + size = "small", + srcs = ["relu1_test.cc"], + tags = [ + "no_oss", + "tflite_not_portable_ios", + ], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) + tf_cc_test( name = "activations_test", size = "small", diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 188015f43c..c66959fdf4 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -25,6 +25,7 @@ TfLiteRegistration* Register_AUDIO_SPECTROGRAM(); TfLiteRegistration* Register_LAYER_NORM_LSTM(); TfLiteRegistration* Register_MFCC(); TfLiteRegistration* Register_DETECTION_POSTPROCESS(); +TfLiteRegistration* Register_RELU_1(); } // namespace custom @@ -249,6 +250,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddCustom("AudioSpectrogram", tflite::ops::custom::Register_AUDIO_SPECTROGRAM()); AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM()); + AddCustom("Relu1", tflite::ops::custom::Register_RELU_1()); AddCustom("TFLite_Detection_PostProcess", tflite::ops::custom::Register_DETECTION_POSTPROCESS()); } diff --git a/tensorflow/contrib/lite/kernels/relu1.cc b/tensorflow/contrib/lite/kernels/relu1.cc new file mode 100644 index 0000000000..abafee2d57 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/relu1.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +namespace tflite { +namespace ops { +namespace custom { +namespace relu1 { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TfLiteTensor* output = GetOutput(context, node, 0); + output->type = input->type; + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +// This is derived from lite/kernels/activations.cc. +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + const int elements = NumElements(input); + const float* in = input->data.f; + const float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; ++in, ++out) { + *out = std::min(std::max(0.f, *in), 1.f); + } + return kTfLiteOk; +} + +} // namespace relu1 + +TfLiteRegistration* Register_RELU_1() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + relu1::Prepare, relu1::Eval}; + return &r; +} + +} // namespace custom +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc new file mode 100644 index 0000000000..c1e0149c20 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/relu1_test.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "flatbuffers/flexbuffers.h" // flatbuffers +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace ops { +namespace custom { + +TfLiteRegistration* Register_RELU_1(); + +namespace { + +using ::testing::ElementsAreArray; + +class BaseActivationsOpModel : public SingleOpModel { + public: + explicit BaseActivationsOpModel(const TensorData& input) { + input_ = AddInput(input); + output_ = AddOutput({input.type, {}}); + flexbuffers::Builder fbb; + fbb.Map([&]() {}); + fbb.Finish(); + SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(FloatActivationsOpTest, Relu1) { + FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, 0.0, 0.2, 0.0, // + 0.3, 0.0, 1.0, 0.0, // + })); +} + +} // namespace +} // namespace custom +} // namespace ops +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} -- GitLab From 734214903cfa8df6d55d25a04748b0989428f2ee Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Tue, 4 Sep 2018 20:48:45 -0700 Subject: [PATCH 100/540] Set session_config.isolate_session_state to True for all strategies except Parameter server strategy where variables are shared across sessions. PiperOrigin-RevId: 211573447 --- .../distribute/python/collective_all_reduce_strategy.py | 2 ++ .../contrib/distribute/python/mirrored_strategy.py | 4 ++++ .../distribute/python/parameter_server_strategy.py | 2 ++ tensorflow/contrib/distribute/python/tpu_strategy.py | 9 +++++++++ 4 files changed, 17 insertions(+) diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index 4fa8aa06cc..77079d0df9 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -229,6 +229,8 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): if not session_config or not self._cluster_spec: return + session_config.isolate_session_state = True + assert self._task_type assert self._task_id is not None diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index d1235b7afb..0c6805d682 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -572,6 +572,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): task_type=None, task_id=None): del task_type, task_id + + if session_config: + session_config.isolate_session_state = True + if cluster_spec: self._initialize_multi_worker(self._num_gpus, cluster_spec) diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py index 88d7768b14..1125d027f6 100644 --- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py +++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py @@ -412,6 +412,8 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy): if not session_config or not self._cluster_spec: return + session_config.isolate_session_state = False + assert self._cluster_spec assert self._task_type assert self._task_id is not None diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 32d7444e42..27853fb317 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -311,3 +311,12 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,) + + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + del cluster_spec, task_type, task_id + if session_config: + session_config.isolate_session_state = True -- GitLab From 67dec723b5d4feaf36b24f164e094d1789ec3a89 Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Tue, 4 Sep 2018 20:55:31 -0700 Subject: [PATCH 101/540] Make minimum num elements of quantizable weights tensor configurable. Also minor fix of enabling quantization of shared weights if hybrid evaluation is true. PiperOrigin-RevId: 211573947 --- .../lite/tools/optimize/quantize_weights.cc | 74 ++++++++++++------- .../lite/tools/optimize/quantize_weights.h | 17 ++++- .../tools/optimize/quantize_weights_test.cc | 30 +++++++- 3 files changed, 90 insertions(+), 31 deletions(-) diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc index e0ed7c7946..e5bb3c990a 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -42,10 +42,9 @@ typedef struct { bool eval_hybrid; } TensorInfo; -// The minimum number of elements a weights array must have to be quantized -// by this transformation. -// TODO(suharshs): Make this configurable. -const int kWeightsMinSize = 1024; +// The default minimum number of elements a weights array must have to be +// quantized by this transformation. +const int kWeightsMinNumElementsDefault = 1024; // Nudge min and max so that floating point 0 falls exactly on a quantized // value, returning the nudges scale and zero_point. @@ -158,42 +157,45 @@ bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) { // Returns a vector of TensorInfos for each input tensor of op that should be // quantized. -std::vector GetQuantizableTensorsFromOperator(const ModelT* model, - const OperatorT* op) { +std::vector GetQuantizableTensorsFromOperator( + const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements, + bool use_hybrid_evaluation) { SubGraphT* subgraph = model->subgraphs.at(0).get(); const BuiltinOperator op_code = model->operator_codes[op->opcode_index]->builtin_code; std::vector tensor_infos; - bool eval_hybrid = IsHybridEvaluationOp(op, op_code); + bool eval_hybrid = use_hybrid_evaluation && IsHybridEvaluationOp(op, op_code); bool skipped_tensor = false; std::vector op_input_indices = GetWeightInputIndices(op_code); for (const int32_t op_input_idx : op_input_indices) { int32_t tensor_idx = op->inputs[op_input_idx]; + TensorT* tensor = subgraph->tensors[tensor_idx].get(); // TODO(suharshs): Support shared weights, i.e. If two tensors share the // same weight array, things may break. (i.e. SSD object detection) - if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) { - LOG(INFO) << "Skipping quantization of tensor that is shared between " - "multiple multiple operations."; + if (!eval_hybrid && + CountTensorConsumers(model, subgraph, tensor_idx) != 1) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is shared between multiple multiple operations."; skipped_tensor = true; continue; } - TensorT* tensor = subgraph->tensors[tensor_idx].get(); - if (tensor->type != TensorType_FLOAT32) { - LOG(INFO) << "Skipping quantization of tensor that is not type float."; + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is not type float."; skipped_tensor = true; continue; } const uint64_t num_elements = NumElements(tensor); - if (num_elements < kWeightsMinSize) { - LOG(INFO) << "Skipping quantization of tensor because it has fewer than " - << kWeightsMinSize << " elements (" << num_elements << ")."; + if (num_elements < weights_min_num_elements) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has fewer than " << weights_min_num_elements + << " elements (" << num_elements << ")."; skipped_tensor = true; continue; } @@ -331,11 +333,10 @@ void MakeTensor(const string& name, const std::vector& shape, tensor->reset(tensor_raw); } -} // namespace - -TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, - const Model* input_model, - bool use_hybrid_evaluation) { +TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + bool use_hybrid_evaluation, + uint64_t weights_min_num_elements) { std::unique_ptr model; model.reset(input_model->UnPack()); @@ -352,11 +353,11 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, for (int i = 0; i < subgraph->operators.size(); ++i) { OperatorT* op = subgraph->operators[i].get(); - std::vector tensor_infos = - GetQuantizableTensorsFromOperator(model.get(), op); + std::vector tensor_infos = GetQuantizableTensorsFromOperator( + model.get(), op, weights_min_num_elements, use_hybrid_evaluation); for (const TensorInfo& tensor_info : tensor_infos) { - if (use_hybrid_evaluation && tensor_info.eval_hybrid) { + if (tensor_info.eval_hybrid) { // Quantize the tensor. TF_LITE_ENSURE_STATUS( SymmetricQuantizeTensor(model.get(), tensor_info.tensor)); @@ -399,9 +400,32 @@ TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, return kTfLiteOk; } +} // namespace + +namespace internal { +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + bool use_hybrid_evaluation) { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation, + kWeightsMinNumElementsDefault); +} +} // namespace internal + +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements) { + return QuantizeWeightsInternal(builder, input_model, true, + weights_min_num_elements); +} + TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model) { - return QuantizeWeights(builder, input_model, true); + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + return QuantizeWeightsInternal(builder, input_model, true, + kWeightsMinNumElementsDefault); } } // namespace optimize diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h index 3743c0ce53..706f10b87b 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h @@ -25,6 +25,8 @@ namespace tflite { namespace optimize { // Quantizes input_model and populates the provided builder with the new model. +// By default only weights tensors weight more than 1024 elements will be +// quantized. // // A tflite::Model can be obtained from the builder with: // const uint8_t* buffer = builder->GetBufferPointer(); @@ -32,11 +34,22 @@ namespace optimize { TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model); -// Same as above, but if use_hybrid_evaluation is false, will disable using -// hybrid eval for operations that support it. +// Same as above, but only weights with greater than or equal +// weights_min_num_elements elements will be quantized. +TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements); + +namespace internal { +// If use_hybrid_evaluation is false, will disable using hybrid eval for +// operations that support it. +// +// We use this internal QuantizeWeights call to test models with hybrid +// evaluation disabled. TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, const Model* input_model, bool use_hybrid_evaluation); +} // namespace internal } // namespace optimize } // namespace tflite diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc index efaf9929e9..387b3471c2 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc @@ -76,7 +76,8 @@ class QuantizeWeightsTest : public ::testing::Test { void CheckWeights(const Model* input_model_packed, const Model* output_model_packed, - bool use_hybrid_evaluation) { + bool use_hybrid_evaluation, + uint64_t weights_min_num_elements = 1024) { std::unique_ptr input_model; input_model.reset(input_model_packed->UnPack()); @@ -113,8 +114,9 @@ class QuantizeWeightsTest : public ::testing::Test { int tensor_size = GetElementsNum(tensor); // If the tensor_size is less than 1024 we expect the tensor to remain // unquantized. - if (tensor_size < 1024) { - ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name; + if (tensor_size < weights_min_num_elements) { + ASSERT_TRUE(tensor->type == TensorType_FLOAT32) + << tensor->name << " of type " << tensor->type; const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx); // The weight tensor should not come from a dequantize op. ASSERT_TRUE(preceding_op == nullptr); @@ -183,7 +185,7 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) { flatbuffers::FlatBufferBuilder builder; // Disable hybrid evaluation. - EXPECT_EQ(QuantizeWeights(&builder, input_model, false), kTfLiteOk); + EXPECT_EQ(internal::QuantizeWeights(&builder, input_model, false), kTfLiteOk); const uint8_t* buffer = builder.GetBufferPointer(); const Model* output_model = GetModel(buffer); @@ -191,6 +193,26 @@ TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) { CheckWeights(input_model, output_model, false); } +TEST_F(QuantizeWeightsTest, SimpleTestWithWeightsMinNumElements) { + string model_path = + "third_party/tensorflow/contrib/lite/tools/optimize/testdata/" + "mobilenet_v1_0.25_128.tflite"; + std::unique_ptr input_fb = + FlatBufferModel::BuildFromFile(model_path.data()); + const Model* input_model = input_fb->GetModel(); + + flatbuffers::FlatBufferBuilder builder; + // Make weights_min_size sufficiently large such that no quantization should + // happen, i.e. the original model is the same size as the old one. + const uint64_t kWeightsMinNumElements = 1000000; + EXPECT_EQ(QuantizeWeights(&builder, input_model, kWeightsMinNumElements), + kTfLiteOk); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + CheckWeights(input_model, output_model, true, kWeightsMinNumElements); +} + // TODO(suharshs): Add tests that run the resulting model. } // namespace -- GitLab From c8be0ea9bb3a86f9bf7b1636246ecef1b9869924 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Tue, 4 Sep 2018 21:34:48 -0700 Subject: [PATCH 102/540] In TPUStrategy.configure, copy cluster spec from cluster resolver so that the user doesn't have to pass it again to session_config. PiperOrigin-RevId: 211576564 --- tensorflow/contrib/distribute/python/tpu_strategy.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 27853fb317..4fb70ec685 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -320,3 +320,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): del cluster_spec, task_type, task_id if session_config: session_config.isolate_session_state = True + cluster_spec = self._tpu_cluster_resolver.cluster_spec() + if cluster_spec: + session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) + -- GitLab From 220a546cfae7459abf7d0e4c50bb9848fa69ff53 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Tue, 4 Sep 2018 21:38:37 -0700 Subject: [PATCH 103/540] Allow configuring session options in keras when running with distribution strategy. PiperOrigin-RevId: 211576839 --- tensorflow/python/keras/backend.py | 18 +++++++++------ .../engine/distributed_training_utils.py | 22 +++++++++++++++++-- tensorflow/python/keras/engine/training.py | 2 ++ 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index b52ab7f05c..7768caeaf0 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -443,13 +443,7 @@ def get_session(): session = default_session else: if _SESSION is None: - if not os.environ.get('OMP_NUM_THREADS'): - config = config_pb2.ConfigProto(allow_soft_placement=True) - else: - num_thread = int(os.environ.get('OMP_NUM_THREADS')) - config = config_pb2.ConfigProto( - intra_op_parallelism_threads=num_thread, allow_soft_placement=True) - _SESSION = session_module.Session(config=config) + _SESSION = session_module.Session(config=get_default_session_config()) session = _SESSION if not _MANUAL_VAR_INIT: with session.graph.as_default(): @@ -468,6 +462,16 @@ def set_session(session): _SESSION = session +def get_default_session_config(): + if not os.environ.get('OMP_NUM_THREADS'): + config = config_pb2.ConfigProto(allow_soft_placement=True) + else: + num_thread = int(os.environ.get('OMP_NUM_THREADS')) + config = config_pb2.ConfigProto( + intra_op_parallelism_threads=num_thread, allow_soft_placement=True) + return config + + # DEVICE MANIPULATION diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py index fcb073322c..c1c4970025 100644 --- a/tensorflow/python/keras/engine/distributed_training_utils.py +++ b/tensorflow/python/keras/engine/distributed_training_utils.py @@ -17,8 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.client import session as session_module from tensorflow.python.framework import tensor_util -from tensorflow.python.keras import backend +from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import distribute as distribute_lib @@ -46,7 +47,7 @@ def set_weights(distribution_strategy, dist_model, weights): assign_ops.append(distribution_strategy.unwrap(sw.assign(w))) weights = weights[num_param:] - backend.get_session().run(assign_ops) + K.get_session().run(assign_ops) def unwrap_values(distribution_strategy, grouped_inputs, grouped_outputs, @@ -269,3 +270,20 @@ def validate_all_tensor_shapes(x, x_values): if x_shape != x_values[i].get_shape().as_list(): raise ValueError('Input tensor shapes do not match for distributed tensor' ' inputs {}'.format(x)) + + +def configure_and_create_session(distribution_strategy): + """Configure session config and create a session with it.""" + # TODO(priyag): Throw error if a session already exists. + session_config = K.get_default_session_config() + distribution_strategy.configure(session_config) + + if distribution_strategy.__class__.__name__ == 'TPUStrategy': + # TODO(priyag): Remove this workaround when Distributed Coordinator is + # integrated with keras and we can create a session from there. + master = distribution_strategy._tpu_cluster_resolver.master() # pylint: disable=protected-access + session = session_module.Session(config=session_config, target=master) + else: + session = session_module.Session(config=session_config) + + K.set_session(session) diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index e07220d15a..966b446f22 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -406,6 +406,8 @@ class Model(Network): self._distribution_strategy = distribute if self._distribution_strategy is not None: self._grouped_model = None + distributed_training_utils.configure_and_create_session( + self._distribution_strategy) if not self.built: # Model is not compilable because it does not know its number of inputs # and outputs, nor their shapes and names. We will compile after the first -- GitLab From 2b9ba9e6969e783f3727a38453749b939226b7e3 Mon Sep 17 00:00:00 2001 From: Billy Lamberta Date: Tue, 4 Sep 2018 22:40:37 -0700 Subject: [PATCH 104/540] edit --- tensorflow/python/ops/array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 48f7d3be40..e7fc4d13b2 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1275,7 +1275,7 @@ unique_with_counts.__doc__ = gen_array_ops.unique_with_counts.__doc__ def split(value, num_or_size_splits, axis=0, num=None, name="split"): """Splits a tensor into sub tensors. - If `num_or_size_splits` is an integer type, then splits `value` + If `num_or_size_splits` is an integer type, then split the `value` along dimension `axis` into `num_split` smaller tensors. Requires that `num_split` evenly divides `value.shape[axis]`. -- GitLab From 606ece2a394943e92890b82e53337cb91a749513 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Tue, 4 Sep 2018 22:42:03 -0700 Subject: [PATCH 105/540] Automated rollback of commit 8cf8afefdb4c240f74a05e24246c8cd2dcce9d54 PiperOrigin-RevId: 211581486 --- tensorflow/contrib/__init__.py | 8 -------- tensorflow/python/__init__.py | 7 ------- tensorflow/python/tools/component_api_helper.py | 2 +- 3 files changed, 1 insertion(+), 16 deletions(-) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9478e42b46..5f477a79a3 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,14 +21,6 @@ from __future__ import print_function import os -from tensorflow.python.tools import component_api_helper -component_api_helper.package_hook( - parent_package_str=( - "tensorflow.contrib"), - child_package_str=( - "tensorflow_estimator.contrib.estimator")) -del component_api_helper - # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 4921ecc43c..a2ab63bb48 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -48,13 +48,6 @@ import numpy as np from tensorflow.python import pywrap_tensorflow -from tensorflow.python.tools import component_api_helper -component_api_helper.package_hook( - parent_package_str='tensorflow.python', - child_package_str=( - 'tensorflow_estimator.python.estimator')) -del component_api_helper - # Protocol buffers from tensorflow.core.framework.graph_pb2 import * from tensorflow.core.framework.node_def_pb2 import * diff --git a/tensorflow/python/tools/component_api_helper.py b/tensorflow/python/tools/component_api_helper.py index e261758add..988ecc61f0 100644 --- a/tensorflow/python/tools/component_api_helper.py +++ b/tensorflow/python/tools/component_api_helper.py @@ -67,7 +67,7 @@ def package_hook(parent_package_str, child_package_str, error_msg=None): """ child_pkg_path = [os.path.join(os.path.dirname(child_pkg.__file__), "..")] try: - parent_pkg.__path__ = child_pkg_path + parent_pkg.__path__ + parent_pkg.__path__ += child_pkg_path except AttributeError: parent_pkg.__path__ = child_pkg_path -- GitLab From f00855ee9c8ae8878a2feca7c2c8a23e4b9c6c11 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 5 Sep 2018 06:06:23 +0000 Subject: [PATCH 106/540] Update include order of the header files in python_op_gen_internal.cc, to conform to `Experimental clang-format Check` Signed-off-by: Yong Tang --- tensorflow/python/framework/python_op_gen_internal.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/python_op_gen_internal.cc b/tensorflow/python/framework/python_op_gen_internal.cc index 7c4941a586..f6aef5bc50 100644 --- a/tensorflow/python/framework/python_op_gen_internal.cc +++ b/tensorflow/python/framework/python_op_gen_internal.cc @@ -23,12 +23,12 @@ limitations under the License. #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_def.pb_text.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor.pb_text.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -- GitLab From 8251fd93c0d50e737a9a083353624817b8d8f3ee Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Wed, 5 Sep 2018 06:21:20 +0000 Subject: [PATCH 107/540] Update tensorflow/core/kernels/non_max_suppression_op.cc for `Experimental clang-format Check` fix. Signed-off-by: Yong Tang --- .../core/kernels/non_max_suppression_op.cc | 53 +++++++++++-------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index c0ea277ed5..c93f668801 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -77,8 +77,7 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context, // Return intersection-over-union overlap between boxes i and j template static inline bool IOUGreaterThanThreshold( - typename TTypes::ConstTensor boxes, int i, int j, - T iou_threshold) { + typename TTypes::ConstTensor boxes, int i, int j, T iou_threshold) { const T ymin_i = std::min(boxes(i, 0), boxes(i, 2)); const T xmin_i = std::min(boxes(i, 1), boxes(i, 3)); const T ymax_i = std::max(boxes(i, 0), boxes(i, 2)); @@ -111,8 +110,9 @@ template static inline std::function CreateIOUSuppressCheckFn( const Tensor& boxes, float threshold) { typename TTypes::ConstTensor boxes_data = boxes.tensor(); - return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1, - std::placeholders::_2, static_cast(threshold)); + return std::bind(&IOUGreaterThanThreshold, boxes_data, + std::placeholders::_1, std::placeholders::_2, + static_cast(threshold)); } static inline std::function CreateOverlapsSuppressCheckFn( @@ -224,11 +224,12 @@ class NonMaxSuppressionOp : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_); + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes, iou_threshold_); const float score_threshold_val = std::numeric_limits::lowest(); DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + score_threshold_val, suppress_check_fn); } private: @@ -267,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); + auto suppress_check_fn = + CreateIOUSuppressCheckFn(boxes, iou_threshold_val); const float score_threshold_val = std::numeric_limits::lowest(); DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + score_threshold_val, suppress_check_fn); } }; @@ -340,7 +342,7 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn); + score_threshold_val_, suppress_check_fn); } }; @@ -360,8 +362,8 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { int num_valid_outputs; DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn, - pad_to_max_output_size_, &num_valid_outputs); + score_threshold_val_, suppress_check_fn, + pad_to_max_output_size_, &num_valid_outputs); // Allocate scalar output tensor for number of indices computed. Tensor* num_outputs_t = nullptr; @@ -417,26 +419,35 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel { CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val); DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + score_threshold_val, suppress_check_fn); } }; REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), NonMaxSuppressionOp); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").TypeConstraint("T").Device(DEVICE_CPU), - NonMaxSuppressionV2Op); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").TypeConstraint("T").Device(DEVICE_CPU), +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV2").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV2Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2") + .TypeConstraint("T") + .Device(DEVICE_CPU), NonMaxSuppressionV2Op); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").TypeConstraint("T").Device(DEVICE_CPU), - NonMaxSuppressionV3Op); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").TypeConstraint("T").Device(DEVICE_CPU), +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV3").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV3Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3") + .TypeConstraint("T") + .Device(DEVICE_CPU), NonMaxSuppressionV3Op); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").TypeConstraint("T").Device(DEVICE_CPU), - NonMaxSuppressionV4Op); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").TypeConstraint("T").Device(DEVICE_CPU), +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV4").TypeConstraint("T").Device(DEVICE_CPU), + NonMaxSuppressionV4Op); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4") + .TypeConstraint("T") + .Device(DEVICE_CPU), NonMaxSuppressionV4Op); REGISTER_KERNEL_BUILDER( -- GitLab From 6b89e9ffc991e0683cecd7a62e04cdf4a8c88356 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 4 Sep 2018 23:53:37 -0700 Subject: [PATCH 108/540] PR #21187: Added a normalization term to ctc_beam_search_decoder for tflite PiperOrigin-RevId: 211586062 --- .../experimental/kernels/ctc_beam_search.h | 18 +++++++++++++++--- .../kernels/ctc_beam_search_decoder_test.cc | 13 ++++++------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h index c658e43092..7c5099235a 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h @@ -257,6 +257,16 @@ void CTCBeamSearchDecoder::Step( } else { max_coeff = raw_input.maxCoeff(); } + + // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))). + float logsumexp = 0.0; + for (int j = 0; j < raw_input.size(); ++j) { + logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff); + } + logsumexp = Eigen::numext::log(logsumexp); + // Final normalization offset to get correct log probabilities. + float norm_offset = max_coeff + logsumexp; + const float label_selection_input_min = (label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_) : -std::numeric_limits::infinity(); @@ -288,10 +298,10 @@ void CTCBeamSearchDecoder::Step( beam_scorer_->GetStateExpansionScore(b->state, previous)); } // Plabel(l=abc @ t=6) *= P(c @ 6) - b->newp.label += raw_input(b->label) - max_coeff; + b->newp.label += raw_input(b->label) - norm_offset; } // Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6) - b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff; + b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset; // P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6) b->newp.total = LogSumExp(b->newp.blank, b->newp.label); @@ -326,6 +336,8 @@ void CTCBeamSearchDecoder::Step( const float logit = top_k ? top_k_logits[ind] : raw_input(ind); // Perform label selection: if input for this label looks very // unpromising, never evaluate it with a scorer. + // We may compare logits instead of log probabilities, + // since the difference is the same in both cases. if (logit < label_selection_input_min) { continue; } @@ -339,7 +351,7 @@ void CTCBeamSearchDecoder::Step( // Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6) beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label); float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total; - c.newp.label = logit - max_coeff + + c.newp.label = logit - norm_offset + beam_scorer_->GetStateExpansionScore(c.state, previous); // P(l=abcd @ t=6) = Plabel(l=abcd @ t=6) c.newp.total = c.newp.label; diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc index 32458305c4..aa42b495bd 100644 --- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc +++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc @@ -117,7 +117,7 @@ TEST(CTCBeamSearchTest, SimpleTest) { EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0.32134813}))); + ElementsAreArray(ArrayFloatNear({-0.357094}))); } TEST(CTCBeamSearchTest, MultiBatchTest) { @@ -148,9 +148,8 @@ TEST(CTCBeamSearchTest, MultiBatchTest) { EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0)); EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2)); // Check log probabilities output. - EXPECT_THAT( - m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572}))); + EXPECT_THAT(m.GetLogProbabilitiesOutput(), + ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958}))); } TEST(CTCBeamSearchTest, MultiPathsTest) { @@ -188,8 +187,8 @@ TEST(CTCBeamSearchTest, MultiPathsTest) { EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear( - {0.91318405, 0.9060272, 1.0780245, 0.64358956}))); + ElementsAreArray( + ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357}))); } TEST(CTCBeamSearchTest, NonEqualSequencesTest) { @@ -223,7 +222,7 @@ TEST(CTCBeamSearchTest, NonEqualSequencesTest) { EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1)); // Check log probabilities output. EXPECT_THAT(m.GetLogProbabilitiesOutput(), - ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005}))); + ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553}))); } } // namespace -- GitLab From 568b763776b7890570d9f6ab9568153329079958 Mon Sep 17 00:00:00 2001 From: Michael Kuperstein Date: Wed, 5 Sep 2018 00:26:04 -0700 Subject: [PATCH 109/540] [XLA] Add some ReduceWindow tests, and make them more robust. PiperOrigin-RevId: 211588937 --- .../compiler/xla/tests/reduce_window_test.cc | 86 +++++++++++++++---- 1 file changed, 68 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 997880a018..a1001296a1 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -613,7 +613,7 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, Array4D input(param.base_bounds[0], param.base_bounds[1], param.base_bounds[2], param.base_bounds[3]); - input.FillIota(1); + input.FillRandom(0.1f, 0.1f); std::unique_ptr input_literal = LiteralUtil::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); @@ -629,7 +629,14 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, auto init_value = CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); - auto computation = param.reducer == kAdd + auto reducer = param.reducer; + if (use_bfloat16() && Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + + auto computation = reducer == kAdd ? CreateScalarAddComputation(FloatType(), &b) : CreateScalarMaxComputation(FloatType(), &b); ReduceWindowWithGeneralPadding( @@ -640,8 +647,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase, /*window_strides=*/param.strides, /*padding=*/padding); - CHECK(param.reducer == kAdd || param.reducer == kMax); - auto reduce_func = param.reducer == kAdd + CHECK(reducer == kAdd || reducer == kMax); + auto reduce_func = reducer == kAdd ? +[](float a, float b) { return a + b; } : +[](float a, float b) { return std::max(a, b); }; std::unique_ptr> expected = @@ -809,6 +816,22 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*pad_high=*/{1, 0, 0, 0}, /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{8, 256, 256, 3}, + /*window_bounds=*/{1, 64, 64, 1}, + /*strides=*/{1, 64, 64, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 0, 2, 1}, + /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 8, 64}, + /*window_bounds=*/{112, 112, 1, 8}, + /*strides=*/{112, 112, 1, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, }; INSTANTIATE_TEST_CASE_P( @@ -930,6 +953,27 @@ struct R3ReduceWindowTestData { {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{95, 202, 251}, /*window_bounds=*/{95, 202, 251}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{999, 57, 3}, /*window_bounds=*/{999, 57, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{178, 302, 64}, /*window_bounds=*/{178, 302, 64}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{63, 261, 257}, /*window_bounds=*/{63, 261, 257}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{9999, 1, 1}, /*window_bounds=*/{9999, 1, 1}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10003, 10, 5}, /*window_bounds=*/{9999, 7, 3}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, }; string R3ReduceWindowTestDataToString( @@ -956,35 +1000,42 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } }; -TEST_P(R3ReduceWindowTest, Add) { +TEST_P(R3ReduceWindowTest, DoIt) { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array3D input(param.base_bounds[0], param.base_bounds[1], - param.base_bounds[2], 1.0f); + param.base_bounds[2]); + input.FillRandom(0.1f, 0.1f); std::unique_ptr input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); + auto reducer = param.reducer; + if (use_bfloat16()) { + input_literal = LiteralUtil::ConvertF32ToBF16(*input_literal); + if (Product(param.window_bounds) > 128) { + // To avoid numerical issues, force the reducer to be kMax for large bf16 + // windows. + reducer = kMax; + } + } - XlaOp parameter; - auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", - &b, ¶meter); + XlaOp parameter = Parameter(&b, 0, input_literal->shape(), "input"); auto init_value = CreateConstantFromLiteral(*LiteralUtil::CreateR0(kInitValue), &b); + + auto computation = reducer == kAdd + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + ReduceWindow(/*operand=*/parameter, /*init_value=*/init_value, - /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); - auto expected = ReferenceUtil::ReduceWindow3DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); - - ComputeAndCompareLiteral(&b, *LiteralUtil::CreateFromArray(*expected), - {input_arg.get()}, DefaultErrorSpec()); + ComputeAndCompare(&b, {std::move(*input_literal)}, DefaultErrorSpec()); } INSTANTIATE_TEST_CASE_P( @@ -1093,7 +1144,6 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, void DoIt() { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); - CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); -- GitLab From c7bd1589d08e84ca215b3c8c4dc3023986522ef7 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 5 Sep 2018 01:00:25 -0700 Subject: [PATCH 110/540] Add support for grouped convolutions to the HloEvaluator. Add a missing check to InferConvolveShape(), the output feature dimension needs to be divisible by feature_group_count. Also fix some tests which took a const reference to the return value of a function which doesn't return a reference. PiperOrigin-RevId: 211592011 --- .../xla/service/hlo_evaluator_test.cc | 75 +++++++++++++++++-- .../xla/service/hlo_evaluator_typed_visitor.h | 36 ++++++++- .../compiler/xla/service/shape_inference.cc | 10 +++ 3 files changed, 112 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 3ab8ef18dd..f586f253da 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -798,7 +798,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) { dnums.set_kernel_input_feature_dimension(1); dnums.add_kernel_spatial_dimensions(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -853,7 +853,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -937,7 +937,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -1015,7 +1015,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.add_kernel_spatial_dimensions(3); dnums.add_kernel_spatial_dimensions(1); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -1075,7 +1075,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -1139,7 +1139,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -1211,7 +1211,7 @@ TEST_P(HloEvaluatorTest, ConvolutionDimensionNumbers dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(2); - const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, DefaultPrecisionConfig(2))); @@ -1236,6 +1236,67 @@ TEST_P(HloEvaluatorTest, EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } +TEST_P(HloEvaluatorTest, Conv2DGroupedConvolution) { + HloComputation::Builder b(TestName()); + std::vector input_dims = {1, 2, 2, 4}; + std::vector filter_dims = {2, 2, 2, 8}; + Shape input_shape = ShapeUtil::MakeShapeWithType(input_dims); + Shape filter_shape = ShapeUtil::MakeShapeWithType(filter_dims); + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + Window window; + WindowDimension dim; + dim.set_size(2); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + std::iota(input_elems.begin(), input_elems.end(), -7); + auto input_r1 = LiteralUtil::CreateR1(input_elems); + auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input_r4))); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + std::iota(filter_elems.begin(), filter_elems.end(), -31); + auto filter_r1 = LiteralUtil::CreateR1(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(filter_r4))); + + Shape shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 8}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, + /*feature_group_count=*/2, window, dnums, DefaultPrecisionConfig(2))); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); + + Array4D expected_array(1, 1, 1, 8); + expected_array.FillWithYX( + Array2D({{668, 664, 660, 656, 668, 680, 692, 704}})); + auto expected = LiteralUtil::CreateR4FromArray4D(expected_array); + EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); +} + class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index dc16a84246..6a09bb08f4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1047,9 +1047,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { auto lhs_literal_data = lhs_literal.data(); auto rhs_literal_data = rhs_literal.data(); + int64 feature_group_count = conv->feature_group_count(); + auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window, &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data, - rhs_literal_data](absl::Span out_index) { + rhs_literal_data, + feature_group_count](absl::Span out_index) { // Dimension number applicable for input (lhs). const int64 input_batch_dim = dnums.input_batch_dimension(); const int64 input_z_dim = dnums.input_feature_dimension(); @@ -1061,6 +1064,8 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { const int64 output_z_dim = dnums.output_feature_dimension(); const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); + const int64 output_z_size = + ShapeUtil::GetDimension(rhs_shape, kernel_output_z_dim); ElementwiseT result_val = static_cast(0); DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(), @@ -1069,6 +1074,33 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Convolve input feature with kernel. do { for (int64 iz = 0; iz < z_size; ++iz) { + int64 rhs_iz = iz; + // Handle grouped convolutions. + if (feature_group_count > 1) { + // The size of a feature group. + int64 feature_group_size = z_size / feature_group_count; + rhs_iz = iz % feature_group_size; + + // The output feature dimension is a concatenation of convolution + // results from the different groups. + int64 output_feature_group_size = + output_z_size / feature_group_count; + + // Calculate the group index to which the current input feature + // index belongs. + int64 input_group_index = iz / feature_group_size; + + // Calculate the group index to which the current output index + // belongs. + int64 output_group_index = + out_index[output_z_dim] / output_feature_group_size; + if (input_group_index != output_group_index) { + // If the current output index does not belong to the current + // feature group, skip it. + continue; + } + } + int64 lhs_linear_index = 0; lhs_linear_index += out_index[output_batch_dim] * lhs_dim_multipliers[input_batch_dim]; @@ -1077,7 +1109,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 rhs_linear_index = 0; rhs_linear_index += out_index[output_z_dim] * rhs_dim_multipliers[kernel_output_z_dim]; - rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim]; + rhs_linear_index += rhs_iz * rhs_dim_multipliers[kernel_input_z_dim]; // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 7758a5dd4d..74bdf2a2e3 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1672,6 +1672,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), dnums.DebugString()); } + if (kernel_output_features % feature_group_count > 0) { + return InvalidArgument( + "Expected output feature dimension (value %d) to be divisible by " + "feature_group_count (value %d); " + "got (%s, %s)\n" + "Dimension numbers: {%s}.", + kernel_output_features, feature_group_count, + ShapeUtil::HumanString(lhs), ShapeUtil::HumanString(rhs), + dnums.DebugString()); + } std::vector window_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { window_dims[i] = window.dimensions(i).size(); -- GitLab From 32e96b1dc588cccf4e008259f831c4e50d948dc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Wed, 5 Sep 2018 15:46:09 +0800 Subject: [PATCH 111/540] ENH: add gradient for broadcast_to --- .../kernel_tests/broadcast_to_ops_test.py | 20 +++++++++++++++++++ tensorflow/python/ops/array_grad.py | 19 ++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 6a1bd958ba..282a619094 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.platform import test as test_lib @@ -81,5 +82,24 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) + def testGradient(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientForScalar(self): + x = constant_op.constant(1, dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 4, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + if __name__ == "__main__": test_lib.main() diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 6ae869b89e..ade86e85bf 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -805,3 +805,22 @@ def _ScatterNdNonAliasingAddGrad(op, grad): indices = op.inputs[1] updates_grad = array_ops.gather_nd(grad, indices) return [grad, None, updates_grad] + + +@ops.RegisterGradient("BroadcastTo") +def _BroadcastToGrad(op, grad): + input_value = op.inputs[0] + broadcast_shape = op.inputs[1] + # Assign ids for each position in input_value. + input_value_shape = array_ops.shape(input_value) + input_value_size = array_ops.size(input_value) + ids = array_ops.reshape(math_ops.range(input_value_size), input_value_shape) + broadcast_ids = array_ops.broadcast_to(ids, broadcast_shape) + # Group by ids and sum its gradients. + grad_flatten = array_ops.reshape(grad, [-1]) + broadcast_ids_flatten = array_ops.reshape(broadcast_ids, [-1]) + updates_grad_flatten = math_ops.unsorted_segment_sum(grad_flatten, + broadcast_ids_flatten, + input_value_size) + updates_grad = array_ops.reshape(updates_grad_flatten, input_value_shape) + return [updates_grad, None] -- GitLab From bd8df09cbd43c7244b4b66c62531eae557c1c468 Mon Sep 17 00:00:00 2001 From: Tom Hennigan Date: Wed, 5 Sep 2018 01:07:14 -0700 Subject: [PATCH 112/540] Update `make_tensor_proto` docs to reference public symbol for `make_ndarray`. PiperOrigin-RevId: 211592901 --- tensorflow/python/framework/tensor_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index b14290c203..26170b000d 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -367,7 +367,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False): A `TensorProto`. Depending on the type, it may contain data in the "tensor_content" attribute, which is not directly useful to Python programs. To access the values you should convert the proto back to a numpy ndarray - with `tensor_util.MakeNdarray(proto)`. + with `tf.make_ndarray(proto)`. If `values` is a `TensorProto`, it is immediately returned; `dtype` and `shape` are ignored. -- GitLab From f15e8613aa42f7f2b1439c652a465438553df219 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 02:02:41 -0700 Subject: [PATCH 113/540] compat: Update forward compatibility horizon to 2018-09-05 PiperOrigin-RevId: 211598349 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 459f494b48..586f4c6936 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 4) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 5) @tf_export("compat.forward_compatible") -- GitLab From 858f4672e25825bc5e091a79fd4234f1968a278d Mon Sep 17 00:00:00 2001 From: Yu-Cheng Ling Date: Wed, 5 Sep 2018 06:01:51 -0700 Subject: [PATCH 114/540] Minimum change for generating Eager ops with Toco. PiperOrigin-RevId: 211621189 --- tensorflow/contrib/lite/toco/args.h | 4 ++ .../contrib/lite/toco/import_tensorflow.cc | 10 +++- .../contrib/lite/toco/import_tensorflow.h | 5 ++ tensorflow/contrib/lite/toco/tflite/export.cc | 52 ++++++++++++------- tensorflow/contrib/lite/toco/tflite/export.h | 51 ++++++++++++++---- .../contrib/lite/toco/tflite/export_test.cc | 9 ++-- .../contrib/lite/toco/tflite/operator.cc | 39 +++++++++++--- .../contrib/lite/toco/tflite/operator.h | 8 ++- .../contrib/lite/toco/toco_cmdline_flags.cc | 18 ++++++- tensorflow/contrib/lite/toco/toco_flags.proto | 15 +++++- tensorflow/contrib/lite/toco/toco_tooling.cc | 24 +++++++-- 11 files changed, 183 insertions(+), 52 deletions(-) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 84f71dc7a7..f14dbc258b 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -247,6 +247,10 @@ struct ParsedTocoFlags { Arg allow_nudging_weights_to_use_fast_gemm_kernel = Arg(false); Arg dedupe_array_min_size_bytes = Arg(64); Arg split_tflite_lstm_inputs = Arg(true); + // WARNING: Experimental interface, subject to change + Arg allow_eager_ops = Arg(false); + // WARNING: Experimental interface, subject to change + Arg force_eager_ops = Arg(false); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index cb6da21039..9bc23c4b3c 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -2061,8 +2061,14 @@ std::unique_ptr ImportTensorFlowGraphDef( } Model* model = new Model; - const internal::ConverterMapType& converter_map = - internal::GetTensorFlowNodeConverterMap(); + internal::ConverterMapType converter_map; + + // This is used for the TFLite "Full Eager Mode" conversion. All the ops are + // imported as `TensorFlowUnsupportedOperator`, and later all these ops are + // converted to TFLite Eager ops. + if (!tf_import_flags.import_all_ops_as_unsupported) { + converter_map = internal::GetTensorFlowNodeConverterMap(); + } for (auto node : inlined_graph.node()) { StripZeroOutputIndexFromInputs(&node); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h index 2177872334..7db23f2d44 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.h +++ b/tensorflow/contrib/lite/toco/import_tensorflow.h @@ -27,6 +27,11 @@ struct TensorFlowImportFlags { // If true, control dependencies will be dropped immediately // during the import of the TensorFlow GraphDef. bool drop_control_dependency = false; + + // Do not recognize any op and import all ops as + // `TensorFlowUnsupportedOperator`. This is used to populated with the + // `force_eager_ops` flag. + bool import_all_ops_as_unsupported = false; }; std::unique_ptr ImportTensorFlowGraphDef( diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index c79469f59b..fee10b1dff 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -49,12 +49,21 @@ namespace { details::OperatorKey GetOperatorKey( const ::toco::Operator& op, - const std::map>& ops_by_type) { + const std::map>& ops_by_type, + bool allow_eager_ops) { string custom_code; if (op.type == OperatorType::kUnsupported) { const TensorFlowUnsupportedOperator& unsupported_op = static_cast(op); - custom_code = unsupported_op.tensorflow_op; + + // TODO(b/113715895): When `allow_eager_ops` is on, for now there's no way + // to populate a regular custom op. We need to find a way to fix this. + if (allow_eager_ops) { + custom_code = string(::tflite::kEagerCustomCodePrefix) + + unsupported_op.tensorflow_op; + } else { + custom_code = unsupported_op.tensorflow_op; + } } int version = 1; if (ops_by_type.count(op.type) != 0) { @@ -91,11 +100,12 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) { void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, - const std::map>& ops_by_type) { + const std::map>& ops_by_type, + bool allow_eager_ops) { // First find a list of unique operator types. std::set keys; for (const auto& op : model.operators) { - keys.insert(GetOperatorKey(*op, ops_by_type)); + keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops)); } // Now assign indices to them and fill in the map. int index = 0; @@ -189,7 +199,7 @@ Offset>> ExportOperatorCodes( const Model& model, const std::map>& ops_by_type, const details::OperatorsMap& operators_map, FlatBufferBuilder* builder, - std::set* error_summary) { + std::set* error_summary, const ExportParams& params) { // Map from operator name to TF Lite enum value, for all builtins. std::map builtin_ops; for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) { @@ -205,7 +215,8 @@ Offset>> ExportOperatorCodes( std::map> ordered_opcodes; for (const auto& op : model.operators) { - const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type); + const details::OperatorKey operator_key = + GetOperatorKey(*op, ops_by_type, params.allow_eager_ops); int op_index = operators_map.at(operator_key); int op_version = operator_key.version; @@ -252,7 +263,7 @@ Offset>> ExportOperators( const std::map>& ops_by_type, const details::OperatorsMap& operators_map, const details::TensorsMap& tensors_map, FlatBufferBuilder* builder, - std::set* variable_tensor_indices) { + std::set* variable_tensor_indices, const ExportParams& params) { variable_tensor_indices->clear(); // The operators are in execution order, so we just follow tf.mini order. @@ -269,7 +280,8 @@ Offset>> ExportOperators( outputs.push_back(tensors_map.at(output)); } - int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type)); + int op_index = operators_map.at( + GetOperatorKey(*op, ops_by_type, params.allow_eager_ops)); auto tflite_op_it = ops_by_type.find(op->type); BaseOperator* tflite_op = tflite_op_it == ops_by_type.end() @@ -320,16 +332,15 @@ Offset>> ExportBuffers( return builder->CreateVector(buffer_vector); } -void Export(const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents) { - const auto ops_by_type = BuildOperatorByTypeMap(); - Export(model, allow_custom_ops, quantize_weights, output_file_contents, - ops_by_type); +void Export(const Model& model, string* output_file_contents, + const ExportParams& params) { + const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops); + Export(model, output_file_contents, params, ops_by_type); } void Export( - const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents, + const Model& model, string* output_file_contents, + const ExportParams& params, const std::map>& ops_by_type) { flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240); @@ -337,7 +348,8 @@ void Export( details::LoadTensorsMap(model, &tensors_map); details::OperatorsMap operators_map; - details::LoadOperatorsMap(model, &operators_map, ops_by_type); + details::LoadOperatorsMap(model, &operators_map, ops_by_type, + params.allow_eager_ops); std::vector buffers_to_write; Array empty_array; @@ -345,7 +357,7 @@ void Export( std::set error_summary; auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, - &builder, &error_summary); + &builder, &error_summary, params); for (const auto& op : model.operators) { if (op->type == OperatorType::kFakeQuant) { @@ -355,7 +367,7 @@ void Export( "for --std_values and --mean_values."; } } - if (!allow_custom_ops && !error_summary.empty()) { + if (!params.allow_custom_ops && !error_summary.empty()) { // Remove ExpandDims and ReorderAxes from unimplemented list unless they // compose the list. Both ops are removed during graph transformations. // However, if an op is unimplemented earlier in the model, the graph @@ -383,7 +395,7 @@ void Export( std::set variable_tensor_indices; auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map, - &builder, &variable_tensor_indices); + &builder, &variable_tensor_indices, params); auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write, variable_tensor_indices); @@ -402,7 +414,7 @@ void Export( builder.CreateVector(subgraphs), description, buffers); ::tflite::FinishModelBuffer(builder, new_model_location); - if (quantize_weights) { + if (params.quantize_weights) { // Call the quantize_weights tool. LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. " "dump_graphviz will only output the model before this " diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 915d5dd3d6..b070a38768 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -23,22 +23,54 @@ namespace toco { namespace tflite { +// The parameters for exporting a TFLite model. +struct ExportParams { + bool allow_custom_ops = false; + bool allow_eager_ops = false; + bool quantize_weights = false; +}; + // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the // result in the given string. -void Export(const Model& model, bool allow_custom_ops, bool quantize_weights, - string* output_file_contents); +void Export(const Model& model, string* output_file_contents, + const ExportParams& params); + +// Export API with custom TFLite operator mapping. +void Export( + const Model& model, string* output_file_contents, + const ExportParams& params, + const std::map>& ops_by_type); -// This if backward-compatibility. +// This is for backward-compatibility. // TODO(ycling): Remove the deprecated entry functions. -inline void Export(const Model& model, string* output_file_contents) { - Export(model, true, false, output_file_contents); +inline void Export(const Model& model, bool allow_custom_ops, + bool quantize_weights, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params); } -// Export API with custom TFLite operator mapping. -void Export( +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export( const Model& model, bool allow_custom_ops, bool quantize_weights, string* output_file_contents, - const std::map>& ops_by_type); + const std::map>& ops_by_type) { + ExportParams params; + params.allow_custom_ops = allow_custom_ops; + params.quantize_weights = quantize_weights; + Export(model, output_file_contents, params, ops_by_type); +} + +// This is for backward-compatibility. +// TODO(ycling): Remove the deprecated entry functions. +inline void Export(const Model& model, string* output_file_contents) { + ExportParams params; + params.allow_custom_ops = true; + Export(model, output_file_contents, params); + Export(model, true, false, output_file_contents); +} namespace details { @@ -88,7 +120,8 @@ using OperatorsMap = std::unordered_map; void LoadTensorsMap(const Model& model, TensorsMap* tensors_map); void LoadOperatorsMap( const Model& model, OperatorsMap* operators_map, - const std::map>& ops_by_type); + const std::map>& ops_by_type, + bool allow_eager_ops); } // namespace details } // namespace tflite diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc index 4994ea30de..8d4d197c46 100644 --- a/tensorflow/contrib/lite/toco/tflite/export_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc @@ -105,7 +105,8 @@ TEST_F(ExportTest, LoadOperatorsMap) { details::OperatorsMap operators; const auto ops_by_type = BuildOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + // TODO(ycling): Add a test for allow_eager_ops. + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]); EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]); EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]); @@ -253,7 +254,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); @@ -264,7 +265,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(1, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2))); @@ -276,7 +277,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) { details::OperatorsMap operators; const auto ops_by_type = BuildFakeOperatorByTypeMap(); - details::LoadOperatorsMap(input_model_, &operators, ops_by_type); + details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false); EXPECT_EQ(2, operators.size()); EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1))); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index a314c8d53a..eb0f7c443a 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1149,7 +1149,9 @@ class Unpack : public BuiltinOperator Deserialize( const BuiltinOptions* builtin_options, const CustomOptions* custom_options) const override { + // Deserializing Eager ops doesn't work now. + // TODO(ycling): Revisit and decide if we should fix the flow for importing + // TFLite models with Eager ops. auto op = absl::make_unique(); if (custom_options) { auto flexbuffer_map = @@ -1185,6 +1190,16 @@ class TensorFlowUnsupported : public BaseOperator { return std::unique_ptr(); } + if (allow_eager_ops_) { + fbb->Vector([&]() { + fbb->String(node_def.op()); + fbb->String(op.tensorflow_node_def); + }); + fbb->Finish(); + LOG(INFO) << "Writing eager op: " << node_def.op(); + return std::unique_ptr(fbb.release()); + } + bool has_valid_attr = false; size_t map_start = fbb->StartMap(); for (const auto& pair : node_def.attr()) { @@ -1285,11 +1300,15 @@ class TensorFlowUnsupported : public BaseOperator { // custom ops. return 1; } + + private: + const bool allow_eager_ops_; }; namespace { // Build a vector containing all the known operators. -std::vector> BuildOperatorList() { +std::vector> BuildOperatorList( + bool allow_eager_ops = false) { std::vector> ops; using tensorflow::MakeUnique; // Builtin Operators. @@ -1400,8 +1419,8 @@ std::vector> BuildOperatorList() { MakeUnique("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.push_back(MakeUnique( "CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder)); - ops.push_back(MakeUnique("TENSORFLOW_UNSUPPORTED", - OperatorType::kUnsupported)); + ops.push_back(MakeUnique( + "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops)); // There operators are supported by Toco, but not by TF Lite, and has no // attributes. @@ -1474,10 +1493,12 @@ std::vector> BuildOperatorList() { } } // namespace -std::map> BuildOperatorByTypeMap() { +std::map> BuildOperatorByTypeMap( + bool allow_eager_ops) { std::map> result; - std::vector> ops = BuildOperatorList(); + std::vector> ops = + BuildOperatorList(allow_eager_ops); for (auto& op : ops) { result[op->type()] = std::move(op); } @@ -1485,10 +1506,12 @@ std::map> BuildOperatorByTypeMap() { return result; } -std::map> BuildOperatorByNameMap() { +std::map> BuildOperatorByNameMap( + bool allow_eager_ops) { std::map> result; - std::vector> ops = BuildOperatorList(); + std::vector> ops = + BuildOperatorList(allow_eager_ops); for (auto& op : ops) { result[op->name()] = std::move(op); } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h index d9ea23edf2..702fb28ea6 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.h +++ b/tensorflow/contrib/lite/toco/tflite/operator.h @@ -26,11 +26,15 @@ namespace tflite { class BaseOperator; // Return a map contained all know TF Lite Operators, keyed by their names. -std::map> BuildOperatorByNameMap(); +// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops) +// is ugly here. Consider refactoring. +std::map> BuildOperatorByNameMap( + bool allow_eager_ops = false); // Return a map contained all know TF Lite Operators, keyed by the type of // their tf.mini counterparts. -std::map> BuildOperatorByTypeMap(); +std::map> BuildOperatorByTypeMap( + bool allow_eager_ops = false); // These are the flatbuffer types for custom and builtin options. using CustomOptions = flatbuffers::Vector; diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index f83a290195..b6aebc0470 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -165,7 +165,13 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.post_training_quantize.default_value(), "Boolean indicating whether to quantize the weights of the " "converted float model. Model size will be reduced and there will " - "be latency improvements (at the cost of accuracy).")}; + "be latency improvements (at the cost of accuracy)."), + // WARNING: Experimental interface, subject to change + Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(), + parsed_flags.allow_eager_ops.default_value(), ""), + // WARNING: Experimental interface, subject to change + Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(), + parsed_flags.force_eager_ops.default_value(), "")}; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); if (asked_for_help) { @@ -260,6 +266,16 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone); + READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone); + READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone); + + if (parsed_toco_flags.force_eager_ops.value() && + !parsed_toco_flags.allow_eager_ops.value()) { + // TODO(ycling): Consider to enforce `allow_eager_ops` when + // `force_eager_ops` is true. + LOG(WARNING) << "--force_eager_ops should always be used with " + "--allow_eager_ops."; + } // Deprecated flag handling. if (parsed_toco_flags.input_type.specified()) { diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index c1dd621429..53d60fed05 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 27. +// Next ID to use: 29. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -189,4 +189,17 @@ message TocoFlags { // model. Model size will be reduced and there will be latency improvements // (at the cost of accuracy). optional bool post_training_quantize = 26 [default = false]; + + // When enabled, unsupported ops will be converted to TFLite Eager ops. + // TODO(ycling): Consider to rename the following 2 flags and don't call it + // "Eager". + // `allow_eager_ops` should always be used with `allow_custom_ops`. + // WARNING: Experimental interface, subject to change + optional bool allow_eager_ops = 27 [default = false]; + + // When enabled, all TensorFlow ops will be converted to TFLite Eager + // ops directly. This will force `allow_eager_ops` to true. + // `force_eager_ops` should always be used with `allow_eager_ops`. + // WARNING: Experimental interface, subject to change + optional bool force_eager_ops = 28 [default = false]; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 7db7acb44d..a7c17156b1 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -197,6 +197,10 @@ std::unique_ptr Import(const TocoFlags& toco_flags, toco_flags.has_drop_control_dependency() ? toco_flags.drop_control_dependency() : (toco_flags.output_format() != TENSORFLOW_GRAPHDEF); + + tf_import_flags.import_all_ops_as_unsupported = + toco_flags.force_eager_ops(); + model = ImportTensorFlowGraphDef(model_flags, tf_import_flags, input_file_contents); break; @@ -397,11 +401,21 @@ void Export(const TocoFlags& toco_flags, const Model& model, case TENSORFLOW_GRAPHDEF: ExportTensorFlowGraphDef(model, output_file_contents); break; - case TFLITE: - toco::tflite::Export(model, allow_custom_ops, - toco_flags.post_training_quantize(), - output_file_contents); - break; + case TFLITE: { + toco::tflite::ExportParams params; + + // Always allow custom ops when eager ops are allowed. + if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) { + params.allow_eager_ops = true; + params.allow_custom_ops = true; + } else if (allow_custom_ops) { + params.allow_custom_ops = true; + } + + params.quantize_weights = toco_flags.post_training_quantize(); + + toco::tflite::Export(model, output_file_contents, params); + } break; case GRAPHVIZ_DOT: DumpGraphviz(model, output_file_contents); break; -- GitLab From ffaab58cad72e177ada0e7d1d3724de63032928d Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 5 Sep 2018 06:07:52 -0700 Subject: [PATCH 115/540] Simplify analysis in funcitonalize_cond by splitting CondState. * Split CondState into CondState (which corresponds to scope previously) and AncestorState (which tracks which switch/merge nodes are an ancestor of a ndoe). Previously CondState tracked both but that resulted in difficult to follow meet rules. Instead by splitting these out the meet for merge and non-merge are straight forward set operations. The ancestor relation is similarly easy to compute along with CondState computation. * Enhance the redundant switch checking: previously we only considered the predicates but %s=switch(val=%P, pred=switch(%P_1, %P):then) is also redundant as if %P is true then %s:else is dead. * Enhance in-edge testing to insert a switch if a value from an outer context is consumed inside an inner context. * Rename CondStateMap to StateMap to match new usage. PiperOrigin-RevId: 211622021 --- .../compiler/tf2xla/functionalize_cond.cc | 787 +++++++++--------- .../compiler/tf2xla/functionalize_cond.h | 166 ++-- .../tf2xla/functionalize_cond_test.cc | 118 +-- 3 files changed, 481 insertions(+), 590 deletions(-) diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index b5667ca0d3..e2affee51f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -40,26 +40,11 @@ using xla::StatusOr; namespace tensorflow { namespace functionalize_cond { -string DebugString(const CondStateMap::CondNode& node) { - return node.ToString(); -} - // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { return strings::StrCat(tensor.node->name(), ":", tensor.index); } -string DebugString(CondStateMap::CondId cond_state) { - if (cond_state == nullptr || cond_state->empty()) return "[]"; - return strings::StrCat( - "[", - absl::StrJoin(*cond_state, ", ", - [](string* output, const CondStateMap::CondNode& node) { - strings::StrAppend(output, node.ToString()); - }), - "]"); -} - string Branch_Name(BranchType b) { switch (b) { case BranchType::kElseBranch: @@ -73,6 +58,24 @@ string Branch_Name(BranchType b) { } } +string DebugString(StateMap::CondId cond_state) { + if (cond_state == nullptr || cond_state->empty()) return "{}"; + using value_type = StateMap::CondState::value_type; + return strings::StrCat( + "{", + absl::StrJoin(*cond_state, ", ", + [](string* output, const value_type& pred_branch) { + const OutputTensor& pred = pred_branch.first; + const BranchType& branch = pred_branch.second; + if (branch == BranchType::kNeither) + strings::StrAppend(output, "d"); + else + strings::StrAppend(output, "s(", DebugString(pred), ",", + Branch_Name(branch), ")"); + }), + "}"); +} + // Returns the predicate of a switch. Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { const Edge* pred_edge; @@ -86,64 +89,65 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { return Status::OK(); } -CondStateMap::CondNode::CondNode(Type type, Node* switch_node, - BranchType branch) - : type(type), branch(branch) { - if (type == Type::kSwitch) { - TF_CHECK_OK(GetSwitchPredicate(*switch_node, &predicate)); - } -} - -string CondStateMap::CondNode::ToString() const { - switch (type) { - case Type::kSwitch: - return strings::StrCat("s(", DebugString(predicate), ",", - Branch_Name(branch), ")"); - case Type::kMerge: - return "m"; - case Type::kDead: - return "d"; - } +Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { + const Edge* val_edge; + TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); + *val = OutputTensor(val_edge->src(), val_edge->src_output()); + return Status::OK(); } -bool CondStateMap::CondNode::operator==(const CondNode& other) const { - if (type != Type::kSwitch) return type == other.type; - return type == other.type && predicate == other.predicate && - branch == other.branch; +bool StateMap::OutputTensorLess::operator()(const OutputTensor& lhs, + const OutputTensor& rhs) const { + return (lhs.node->id() < rhs.node->id()) || + (lhs.node->id() == rhs.node->id() && lhs.index < rhs.index); } -bool CondStateMap::CondNode::operator!=(const CondNode& other) const { - return !(*this == other); -} +struct CondStateLess { + bool operator()(const StateMap::CondState::value_type& lhs, + const StateMap::CondState::value_type& rhs) const { + if (StateMap::OutputTensorLess().operator()(lhs.first, rhs.first)) + return true; + if (lhs.first.node->id() == rhs.first.node->id() && + lhs.first.index == rhs.first.index) + return lhs.second < rhs.second; + return false; + } +}; -CondStateMap::CondStateMap(Graph* graph) { +StateMap::StateMap(Graph* graph) { node_to_condid_map_.resize(graph->num_node_ids()); + node_to_ancestorid_map_.resize(graph->num_node_ids()); // Initialize the dead state (empty state is designated with a nullptr). - dead_id_ = GetUniqueId({CondNode(CondStateMap::CondNode::Type::kDead)}); + dead_id_ = GetCondId( + {std::make_pair(OutputTensor(nullptr, -1), BranchType::kNeither)}); } -bool CondStateMap::IsDead(CondStateMap::CondId id) const { - return id == dead_id_; -} +bool StateMap::IsDead(StateMap::CondId id) const { return id == dead_id_; } -bool CondStateMap::IsEmpty(CondStateMap::CondId id) const { - return id == nullptr; -} +bool StateMap::IsEmpty(StateMap::CondId id) const { return id == nullptr; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondNode& item) const { - return Hash64Combine(Hash64Combine(OutputTensor::Hash()(item.predicate), - hash()(item.branch)), - hash()(item.type)); +size_t StateMap::Hash::operator()(const StateMap::CondState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second)); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. + h = Hash64Combine(h, Hash64Combine(OutputTensor::Hash()(it->first), + hash()(it->second))); + } + return h; } -size_t CondStateMap::CondHash::operator()( - const CondStateMap::CondState& vec) const { - if (vec.empty()) return 0; - size_t h = (*this)(vec.front()); - auto it = vec.begin(); - for (++it; it != vec.end(); ++it) { - h = Hash64Combine(h, (*this)(*it)); +size_t StateMap::Hash::operator()(const StateMap::AncestorState& map) const { + if (map.empty()) return 0; + // Compute hash of the front element. + auto it = map.begin(); + size_t h = hash()(*it); + for (++it; it != map.end(); ++it) { + // Combine the has with the different elements in the map. + h = Hash64Combine(h, hash()(*it)); } return h; } @@ -176,49 +180,71 @@ string DebugString(const CondArgNodes& nodes) { "]"); } -CondStateMap::CondId CondStateMap::LookupId(const Node* node) const { +StateMap::CondId StateMap::LookupCondId(const Node* node) const { if (node->id() < node_to_condid_map_.size()) return node_to_condid_map_[node->id()]; - return added_node_mapping_.at(node->id()); + return added_node_condid_mapping_.at(node->id()); } -CondStateMap::CondId CondStateMap::GetUniqueId( - const CondStateMap::CondState& state) { +StateMap::CondId StateMap::GetCondId(const StateMap::CondState& state) { if (state.empty()) return nullptr; return &*condstate_set_.insert(state).first; } -const CondStateMap::CondState& CondStateMap::LookupState( - const Node* node) const { - return *LookupId(node); -} - -void CondStateMap::ResetId(const Node* node, CondStateMap::CondId id) { +void StateMap::ResetCondId(const Node* node, StateMap::CondId id) { if (node->id() < node_to_condid_map_.size()) node_to_condid_map_[node->id()] = id; else - added_node_mapping_[node->id()] = id; + added_node_condid_mapping_[node->id()] = id; +} + +StateMap::AncestorId StateMap::LookupAncestorId(const Node* node) const { + if (node->id() < node_to_ancestorid_map_.size()) + return node_to_ancestorid_map_[node->id()]; + return added_node_ancestorid_mapping_.at(node->id()); } -void CondStateMap::MarkDead(const Node* node) { ResetId(node, dead_id_); } +StateMap::AncestorId StateMap::GetAncestorId( + const StateMap::AncestorState& state) { + if (state.empty()) return nullptr; + return &*ancestorstate_set_.insert(state).first; +} -string CondStateMap::CondStateToString(const Node* node) const { - return CondStateToString(LookupId(node)); +void StateMap::ResetAncestorId(const Node* node, StateMap::AncestorId id) { + if (node->id() < node_to_ancestorid_map_.size()) + node_to_ancestorid_map_[node->id()] = id; + else + added_node_ancestorid_mapping_[node->id()] = id; } -string CondStateMap::CondStateToString(CondStateMap::CondId id) const { +const StateMap::CondState& StateMap::LookupState(const Node* node) const { + return *LookupCondId(node); +} + +void StateMap::MarkDead(const Node* node) { ResetCondId(node, dead_id_); } + +string StateMap::CondStateToString(const Node* node) const { + return CondStateToString(LookupCondId(node)); +} + +string StateMap::CondStateToString(StateMap::CondId id) const { return DebugString(id); } +string StateMap::AncestorStateToString(const Node* node) const { + if (auto id = LookupAncestorId(node)) return NodesToString(*id); + return "{}"; +} + FunctionalizeCond::FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : cond_state_map_(graph), library_(library), graph_(graph) {} + : state_map_(graph), library_(library), graph_(graph) {} // Class representing the merge/switch nodes that will become a conditional. class Conditional { public: Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map); + StateMap* cond_state_map); // Adds merge node that is part of this conditional. Status AddMerge(Node* m); @@ -247,6 +273,10 @@ class Conditional { // Adds switch node that is part of this conditional. Status AddSwitch(Node* s); + // Adds a switch node along the edge and rewire the edge to go via the switch. + Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph); + // Internal name of conditional. The name is based on the first merge node // added. string name() const; @@ -255,7 +285,7 @@ class Conditional { FunctionalizeCond* parent_; // Mapping between nodes and their cond state. - CondStateMap* cond_state_map_; + StateMap* state_map_; // The predicate of the conditional. OutputTensor predicate_; @@ -292,8 +322,8 @@ class Conditional { }; Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, - CondStateMap* cond_state_map) - : parent_(parent), cond_state_map_(cond_state_map), predicate_(predicate) {} + StateMap* cond_state_map) + : parent_(parent), state_map_(cond_state_map), predicate_(predicate) {} Status Conditional::AddMerge(Node* m) { merges_.insert(m); @@ -397,6 +427,35 @@ Status Conditional::BuildArgumentNodes() { return Status::OK(); } +Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph) { + // Previously we had edge: + // src:src_output ---- edge ----> dst:dst_input + // post this we have (in graph) + // src:src_output --> switch --- new_edge --> dst:dst_input + + // TODO(jpienaar): One could keep a map caching the extra switch nodes added + // to avoid adding another switch to feed a value for which a switch was + // already added. + Node* switch_node; + Node* src = edge->src(); + int src_output = edge->src_output(); + TF_RETURN_IF_ERROR( + NodeBuilder(graph->NewName(strings::StrCat(src->name(), "_added_switch")), + "Switch") + .Input(src, src_output) + .Input(const_cast(predicate_.node), predicate_.index) + .Finalize(graph, &switch_node)); + state_map_->ResetCondId(switch_node, state_map_->LookupCondId(src)); + state_map_->ResetAncestorId(switch_node, state_map_->LookupAncestorId(src)); + + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + graph->AddEdge(switch_node, static_cast(branch), dst, dst_input); + return AddSwitch(switch_node); +} + Status Conditional::ExtractBodies(Graph* graph) { VLOG(2) << "Extracting bodies for " << name(); for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { @@ -405,16 +464,16 @@ Status Conditional::ExtractBodies(Graph* graph) { } auto find_branch = [&](const Edge* e) { - const auto& id = cond_state_map_->LookupId(e->src()); + const auto& id = state_map_->LookupCondId(e->src()); return IsSwitch(e->src()) ? BranchType(e->src_output()) - : cond_state_map_->FindBranchOf(id, predicate_); + : state_map_->FindBranchOf(id, predicate_); }; std::array, 2> stacks; VLOG(5) << "Merges: " << NodesToString(merges_); for (Node* m : merges_) { VLOG(5) << "For merge: " << m->DebugString() << " " - << cond_state_map_->CondStateToString(m); + << state_map_->CondStateToString(m); for (auto e : m->in_edges()) { if (e->IsControlEdge()) continue; BranchType branch = find_branch(e); @@ -422,7 +481,8 @@ Status Conditional::ExtractBodies(Graph* graph) { branch == BranchType::kElseBranch) << "Error: " << e->src()->name() << " is not on either then or else branch (" << Branch_Name(branch) - << ")."; + << ") for predicate " << DebugString(predicate_) << " [" + << DebugString(state_map_->LookupCondId(e->src())) << "]."; Node* src = e->src(); if (IsSwitch(src)) { // Switch node outputs and dependencies are handled separately. @@ -456,8 +516,8 @@ Status Conditional::ExtractBodies(Graph* graph) { if (IsMerge(dst)) continue; Node* src = e->src(); - auto dst_id = cond_state_map_->LookupId(dst); - auto src_id = cond_state_map_->LookupId(src); + auto dst_id = state_map_->LookupCondId(dst); + auto src_id = state_map_->LookupCondId(src); if (dst_id != src_id) { if (e->IsControlEdge()) { external_control_outputs_.push_back(e->src()); @@ -480,8 +540,11 @@ Status Conditional::ExtractBodies(Graph* graph) { } } - // Copying incomming edges to dst node. - for (const Edge* e : n->in_edges()) { + // Copying incomming edges to dst node. Iterate over a copy of the edges + // as they could be mutated during iteration. + std::vector in_edges(n->in_edges().begin(), + n->in_edges().end()); + for (const Edge* e : in_edges) { Node* src = e->src(); // Skip src/dst node. if (!src->IsOp()) continue; @@ -494,8 +557,8 @@ Status Conditional::ExtractBodies(Graph* graph) { } // Verify input is from the same context. - auto src_id = cond_state_map_->LookupId(src); - auto dst_id = cond_state_map_->LookupId(dst); + auto src_id = state_map_->LookupCondId(src); + auto dst_id = state_map_->LookupCondId(dst); if (IsMerge(dst) || src_id == dst_id) { // TODO(jpienaar): The merge case can be more strict. if (node_map.at(src->id()) == nullptr) { @@ -506,18 +569,25 @@ Status Conditional::ExtractBodies(Graph* graph) { external_control_inputs_.push_back(src); } else { // This shouldn't happen, this means we have an external data input - // not entering via a switch node. Work around this for constant - // nodes as some constant nodes are inserted without the required - // control context dominance. + // not entering via a switch node. Work around this by for + // * constant nodes copy them; + // * non-constant nodes, insert a switch along the edge; if (IsConstant(src)) { node_map.at(src->id()) = output->CopyNode(src); } else { - return errors::InvalidArgument( - "Graph contains node ", FormatNodeForError(*src), - " that feeds into node ", FormatNodeForError(*dst), - " but these nodes are in different control contexts (", - DebugString(src_id), " vs ", DebugString(dst_id), - " (detected during in edge testing)"); + StateMap::CondState state = *dst_id; + state.erase(predicate_); + if (state_map_->GetCondId(state) == src_id) { + TF_RETURN_IF_ERROR(AddSwitchNodeAlongEdge(e, branch, graph)); + continue; + } else { + return errors::InvalidArgument( + "Graph contains node ", FormatNodeForError(*src), + " that feeds into node ", FormatNodeForError(*dst), + " but these nodes are in different control contexts (", + DebugString(src_id), " vs ", DebugString(dst_id), + " (detected during in edge testing)"); + } } } @@ -639,7 +709,8 @@ Status Conditional::BuildIfNode(Graph* graph, VLOG(3) << "Build If node"; NodeDef if_def; TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); - TF_ASSIGN_OR_RETURN(if_node_, parent_->AddIfNode(if_def, *merges_.begin())); + TF_ASSIGN_OR_RETURN(if_node_, + parent_->AddIfNode(if_def, *merges_.begin(), predicate_)); return Status::OK(); } @@ -699,7 +770,8 @@ Status Conditional::AddOutputEdges(Graph* graph) { Status Conditional::BuildAndReplace(Graph* graph, FunctionLibraryDefinition* library) { - VLOG(1) << "Build If and replace merge nodes " << name(); + VLOG(1) << "Build If and replace merge nodes " + << NodesToString(this->merges_); if (replaced_) return Status::OK(); TF_RETURN_IF_ERROR(ExtractBodies(graph)); @@ -719,7 +791,7 @@ Status Conditional::BuildAndReplace(Graph* graph, TF_RETURN_IF_ERROR(AddInputEdges(graph)); TF_RETURN_IF_ERROR(AddOutputEdges(graph)); TF_RETURN_IF_ERROR(parent_->PropagateUpdatedState(if_node_)); - for (Node* m : merges_) cond_state_map_->MarkDead(m); + for (Node* m : merges_) state_map_->MarkDead(m); // Check that the if_node doesn't feed into itself. TF_RETURN_WITH_CONTEXT_IF_ERROR( @@ -735,55 +807,41 @@ string Conditional::name() const { return strings::StrCat((*merges_.begin())->name(), "_if"); } -bool CondStateMap::ScopeIn(CondStateMap::CondId id, - CondStateMap::CondId* scope) { - if (id == nullptr) { - *scope = nullptr; - return true; - } - CondState state; - for (const CondNode& node : *id) { - if (node.type == CondNode::Type::kSwitch) { - state.push_back(node); - } - if (node.type == CondNode::Type::kMerge) { - if (state.empty()) { - return false; - } - DCHECK(state.back().type == CondNode::Type::kSwitch && - state.back().branch == BranchType::kBoth); - state.pop_back(); - } - } - *scope = GetUniqueId(state); - return true; -} - Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, int port) { Node* id; TF_RETURN_IF_ERROR(NodeBuilder(replacee->name(), "Identity") .Input(if_node, port) .Finalize(graph_, &id)); - cond_state_map_.ResetId(id, cond_state_map_.LookupId(if_node)); + state_map_.ResetCondId(id, state_map_.LookupCondId(if_node)); + state_map_.ResetAncestorId(id, state_map_.LookupAncestorId(if_node)); return Status::OK(); } StatusOr FunctionalizeCond::AddIfNode(const NodeDef& def, - const Node* replacee) { + const Node* replacee, + const OutputTensor& predicate) { Status status; Node* ret = graph_->AddNode(def, &status); TF_RETURN_IF_ERROR(status); - CondStateMap::CondState state = cond_state_map_.LookupState(replacee); - state.pop_back(); VLOG(1) << "Adding If for " << replacee->name(); - cond_state_map_.ResetId(ret, cond_state_map_.GetUniqueId(state)); + StateMap::CondId id = state_map_.LookupCondId(replacee); + if (id) { + StateMap::CondState state = *id; + state.erase(predicate); + state_map_.ResetCondId(ret, state_map_.GetCondId(state)); + } else { + state_map_.ResetCondId(ret, nullptr); + } + + state_map_.ResetAncestorId(ret, state_map_.LookupAncestorId(replacee)); + return ret; } Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { VLOG(2) << "Propagating update state for " << replacee->name() << " " - << cond_state_map_.CondStateToString(replacee); + << state_map_.CondStateToString(replacee); // Redo topological sort as the order could have changed. // TODO(jpienaar): The original topological order could also be updated // dynamically if needed. @@ -801,10 +859,10 @@ Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { if (changed.find(*it) != changed.end()) { // Update the node state. Node* n = *it; - CondStateMap::CondId old_state = cond_state_map_.LookupId(n); - cond_state_map_.ResetId(n, nullptr); + StateMap::CondId old_state = state_map_.LookupCondId(n); + state_map_.ResetCondId(n, nullptr); TF_RETURN_IF_ERROR(DetermineCondState(n)); - if (cond_state_map_.LookupId(n) != old_state) { + if (state_map_.LookupCondId(n) != old_state) { for (auto out : n->out_nodes()) if (out->IsOp()) changed.insert(out); } @@ -825,127 +883,44 @@ BranchType MeetBranch(const BranchType& lhs, const BranchType& rhs) { return BranchType::kNeither; } -CondStateMap::ContainsResult CondStateMap::LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - CondId lhs_scope; - CondId rhs_scope; - bool could_determine_scope = ScopeIn(lhs, &lhs_scope); - could_determine_scope = could_determine_scope && ScopeIn(rhs, &rhs_scope); - if (!could_determine_scope) return kIncomparable; - - // Returns whether a contains b. - auto contains = [&](CondId a, CondId b) { - // Handle empty states. - if (a == nullptr && b != nullptr) return true; - if (a == nullptr && b == nullptr) return true; - if (a != nullptr && b == nullptr) return false; - - if (a->size() > b->size()) return false; - auto a_it = a->begin(); - auto b_it = b->begin(); - while (a_it != a->end()) { - if (*a_it != *b_it) { - if (!(a_it->predicate == b_it->predicate)) return false; - BranchType mb = MeetBranch(a_it->branch, b_it->branch); - if (mb != b_it->branch) return false; - } - ++a_it; - ++b_it; - } - return true; - }; - - bool lhs_contains_rhs = contains(lhs_scope, rhs_scope); - bool rhs_contains_lhs = contains(rhs_scope, lhs_scope); - if (lhs_contains_rhs && rhs_contains_lhs) return kEqual; - if (lhs_contains_rhs) return kLhsContainsRhs; - if (rhs_contains_lhs) return kRhsContainsLhs; - return kIncomparable; -} - -BranchType CondStateMap::FindBranchOf(CondId id, OutputTensor predicate) const { +BranchType StateMap::FindBranchOf(CondId id, OutputTensor predicate) const { if (IsEmpty(id)) return BranchType::kNeither; - absl::optional b; const CondState& nodes = *id; - for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == predicate) { - if (b.has_value()) { - b = MeetBranch(*b, it->branch); - } else { - b = it->branch; - } - if (*b == BranchType::kNeither) { - LOG(FATAL) << "Inconsistent state for node: " << DebugString(id); - } - } - } - return b.has_value() ? *b : BranchType::kNeither; + auto it = nodes.find(predicate); + if (it == nodes.end()) return BranchType::kNeither; + return it->second; } -StatusOr FunctionalizeCond::JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - VLOG(4) << "Joining src=" << DebugString(src) << " [" << src +StatusOr FunctionalizeCond::JoinCondStatesNonMerge( + StateMap::CondId src, StateMap::CondId dst) { + VLOG(5) << "Joining src=" << DebugString(src) << " [" << src << "] and dst=" << DebugString(dst) << " [" << dst << "]"; - if (cond_state_map_.IsEmpty(dst) || cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; + if (state_map_.IsEmpty(dst) || state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst) || state_map_.IsEmpty(src)) return dst; // Nothing to do if the CondState is the same. if (src == dst) return src; - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - auto result = cond_state_map_.LhsHoldsWhereverRhsHolds(src_scope, dst_scope); - switch (result) { - case CondStateMap::kIncomparable: - return errors::InvalidArgument( - "Graph contains node with inputs predicated on incompatible " - "predicates: ", - DebugString(src), " and ", DebugString(dst)); - case CondStateMap::kEqual: - // If both respect the same predicates, propagate the longer constraint. - if ((src != nullptr && dst == nullptr) || - (src != nullptr && dst != nullptr && src->size() > dst->size())) - return src; - else - return dst; - case CondStateMap::kLhsContainsRhs: - // src contains dst, so dst is already more restrictive. - return dst; - case CondStateMap::kRhsContainsLhs: - // dst contains src, so src is more restrictive. - return src; - } -} - -StatusOr -FindThenElseSwitchForPredicate(const OutputTensor& pred, - CondStateMap::CondId id) { - for (auto it = id->begin(); it != id->end(); ++it) { - // Along every path one there can be only one instance of a then or else - // switch for a given predicate, so return once found. - if (it->type == CondStateMap::CondNode::Type::kSwitch && - it->predicate == pred && - (it->branch == BranchType::kThenBranch || - it->branch == BranchType::kElseBranch)) - return it; + StateMap::CondState both = *src; + for (const auto& kv : *dst) { + auto it = both.find(kv.first); + if (it == both.end()) { + both.insert(kv); + } else { + if (it->second != kv.second) { + return errors::InvalidArgument( + "Graph contains node with inputs predicated on incompatible " + "predicates: ", + DebugString(src), " and ", DebugString(dst)); + } + } } - return errors::Internal("Unable to find then/else branch with predicate ", - DebugString(pred), " for ", DebugString(id)); + return state_map_.GetCondId(both); } -StatusOr FunctionalizeCond::JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { +StatusOr FunctionalizeCond::JoinCondStatesMerge( + Node* merge, StateMap::CondId src, StateMap::CondId dst) { // Determine the flow state when joining two states for a merge // node. Combining the two states for a merge node is effectively performing a // disjunction of the states along the different input edges. For a merge that @@ -956,91 +931,56 @@ StatusOr FunctionalizeCond::JoinCondStatesMerge( // followed by s(p, both). VLOG(4) << "Joining (for merge) " << DebugString(src) << " and " << DebugString(dst); - if (cond_state_map_.IsEmpty(dst)) return src; - - if (cond_state_map_.IsDead(src)) return src; - if (cond_state_map_.IsDead(dst)) return dst; - - CondStateMap::CondId src_scope; - CondStateMap::CondId dst_scope; - if (!cond_state_map_.ScopeIn(src, &src_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(src)); - if (!cond_state_map_.ScopeIn(dst, &dst_scope)) - return errors::Unimplemented( - "Predicates that must hold for node to execute are invalid! ", - DebugString(dst)); - - TF_RET_CHECK(src_scope != nullptr && dst_scope != nullptr) - << "Illegal merge inputs from outer scope: src=" << DebugString(src) - << " dst=" << DebugString(dst); - auto src_it = src_scope->begin(); - auto dst_it = dst_scope->begin(); - - // Find branch divergent condition. - OutputTensor pred; - while (src_it != src_scope->end() && dst_it != dst_scope->end()) { - if (*src_it != *dst_it) { - VLOG(5) << "Diverges with: " << DebugString(*src_it) << " and " - << DebugString(*dst_it); - if (!(src_it->predicate == dst_it->predicate)) { - return errors::InvalidArgument( - "Unable to find common predicate which holds for one input " - "but not the other of the merge node."); - } - pred = src_it->predicate; - break; - } - ++src_it; - ++dst_it; - } - - if (pred.node == nullptr) - return errors::InvalidArgument("Unable to determine predicate for merge."); - - TF_ASSIGN_OR_RETURN(auto div_src_it, - FindThenElseSwitchForPredicate(pred, src)); - TF_ASSIGN_OR_RETURN(auto div_dst_it, - FindThenElseSwitchForPredicate(pred, dst)); - TF_RET_CHECK(*div_src_it != *div_dst_it); - - CondStateMap::CondState result; - // Populate result with the longest/most restrictive path up to the divergent - // node. For example, if the one input is `[switch(pred:0, then)]` and the - // other is `[switch(pred:0, both), merge, switch(pred:0, else)]` (as created - // in gradient of cond test), then the resultant state here should be - // `[switch(pred:0, both), merge, switch(pred:0, both)]`. - if (std::distance(src->begin(), div_src_it) > - std::distance(dst->begin(), div_dst_it)) { - result.assign(src->begin(), std::next(div_src_it)); + if (state_map_.IsEmpty(dst)) return src; + + if (state_map_.IsDead(src)) return src; + if (state_map_.IsDead(dst)) return dst; + + std::vector diff; + StateMap::CondState merged; + std::set_symmetric_difference(src->begin(), src->end(), dst->begin(), + dst->end(), std::back_inserter(diff), + CondStateLess()); + std::set_intersection(src->begin(), src->end(), dst->begin(), dst->end(), + std::inserter(merged, merged.begin()), CondStateLess()); + + // Update mapping from merge node to predicate. + if (diff.size() == 2) { + auto pred = diff[0].first; + bool different_branches = (diff[0].second != diff[1].second) && + (diff[0].second == BranchType::kThenBranch || + diff[0].second == BranchType::kElseBranch) && + (diff[1].second == BranchType::kThenBranch || + diff[1].second == BranchType::kElseBranch); + if (!(pred == diff[1].first) || !different_branches) + return errors::InvalidArgument( + "Unable to determine predicate for merge node"); + merge_to_predicate_[merge] = pred; } else { - result.assign(dst->begin(), std::next(div_dst_it)); + return errors::InvalidArgument( + "Merge of two inputs that differ on more than one predicate ", + DebugString(src), " and ", DebugString(dst)); } - result.back().branch = BranchType::kBoth; - return cond_state_map_.GetUniqueId(result); + + return state_map_.GetCondId(merged); } -CondStateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { +StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { Node* src = e->src(); - CondStateMap::CondId id = cond_state_map_.LookupId(e->src()); - if (IsMerge(src)) { - CondStateMap::CondState state; - if (id != nullptr) state = *id; - state.emplace_back(CondStateMap::CondNode::Type::kMerge); - return cond_state_map_.GetUniqueId(state); - } + StateMap::CondId id = state_map_.LookupCondId(e->src()); + + // Dead nodes only propagate dead state. + if (state_map_.IsDead(id)) return id; + if (IsSwitch(src)) { - CondStateMap::CondState state; + StateMap::CondState state; if (id != nullptr) state = *id; - if (e->IsControlEdge()) { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType::kBoth); - } else { - state.emplace_back(CondStateMap::CondNode::Type::kSwitch, src, - BranchType(e->src_output())); + OutputTensor predicate; + TF_CHECK_OK(GetSwitchPredicate(*src, &predicate)); + if (!e->IsControlEdge()) { + state[predicate] = BranchType(e->src_output()); } - return cond_state_map_.GetUniqueId(state); + return state_map_.GetCondId(state); } return id; } @@ -1049,22 +989,21 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { // Only Merge nodes with two inputs are supported, but if this is a redundant // merge, then the dead edge may already have been removed (if due to a // switch) and so the input count would be incorrect. - if (cond_state_map_.IsDead(cond_state_map_.LookupId(dst))) - return Status::OK(); + if (state_map_.IsDead(state_map_.LookupCondId(dst))) return Status::OK(); int data_inputs = 0; for (auto e : dst->in_edges()) { Node* src = e->src(); VLOG(5) << "Processing forward flow for merge: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(src); + << state_map_.CondStateToString(src); if (!src->IsOp()) continue; if (!e->IsControlEdge()) ++data_inputs; - CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesMerge(prop, cond_state_map_.LookupId(dst)); + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesMerge(dst, prop, state_map_.LookupCondId(dst)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } // Incomplete Merge nodes are not supported. @@ -1076,27 +1015,20 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { return Status::OK(); } -Status FunctionalizeCond::DetermineCondState(Node* dst) { - // The logic for the merge and non-merge case differ: for non-merge it is - // the most restrictive CondState, while for merge nodes the - // resultant state is less restrictive than either. - if (IsMerge(dst)) { - TF_RETURN_IF_ERROR(DetermineCondStateMerge(dst)); - } else { - // Handle non-merge join. - for (auto e : dst->in_edges()) { - VLOG(5) << "Processing forward flow for: " << e->DebugString() << " " - << cond_state_map_.CondStateToString(dst); - Node* src = e->src(); - if (!src->IsOp()) continue; - - // Joining the state between the current and propagated state. - CondStateMap::CondId prop = StateAlongEdge(e); - auto id_or = JoinCondStatesNonMerge(prop, cond_state_map_.LookupId(dst)); - TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", - FormatNodeForError(*dst)); - cond_state_map_.ResetId(dst, id_or.ValueOrDie()); - } +Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { + // Handle non-merge join. + for (auto e : dst->in_edges()) { + VLOG(4) << "Processing forward flow for: " << e->DebugString() << " " + << state_map_.CondStateToString(dst); + Node* src = e->src(); + if (!src->IsOp()) continue; + + // Joining the state between the current and propagated state. + StateMap::CondId prop = StateAlongEdge(e); + auto id_or = JoinCondStatesNonMerge(prop, state_map_.LookupCondId(dst)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", + FormatNodeForError(*dst)); + state_map_.ResetCondId(dst, id_or.ValueOrDie()); } return Status::OK(); } @@ -1104,8 +1036,7 @@ Status FunctionalizeCond::DetermineCondState(Node* dst) { Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { // Handle redundant merge nodes. A merge node is considered redundant if // one input edge is dead while the other has a value. - if (!cond_state_map_.IsDead(cond_state_map_.LookupId(node))) - return Status::OK(); + if (!state_map_.IsDead(state_map_.LookupCondId(node))) return Status::OK(); const Edge* non_dead_edge = nullptr; for (auto e : node->in_edges()) { @@ -1113,8 +1044,8 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { Node* src = e->src(); // Handle merge with dead state. - const auto& src_id = cond_state_map_.LookupId(src); - if (!cond_state_map_.IsDead(src_id)) { + const auto& src_id = state_map_.LookupCondId(src); + if (!state_map_.IsDead(src_id)) { non_dead_edge = e; break; } @@ -1124,7 +1055,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { return errors::InvalidArgument("Merge node ", FormatNodeForError(*node), " has no non-dead inputs."); } - cond_state_map_.MarkDead(node); + state_map_.MarkDead(node); delete_nodes_.push_back(node->id()); VLOG(5) << "removing redundant merge: " << node->name(); while (!node->out_edges().empty()) { @@ -1149,16 +1080,33 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { // along one. The checking of predicate is based on the exact predicate // (rather than boolean equivalence) and aimed at redundant switches as // currently generated by gradient code. + StateMap::CondId dst_id = state_map_.LookupCondId(node); + if (state_map_.IsDead(dst_id)) return Status::OK(); + + BranchType b; OutputTensor pred; TF_RETURN_IF_ERROR(GetSwitchPredicate(*node, &pred)); - auto dst_id = cond_state_map_.LookupId(node); - BranchType b = cond_state_map_.FindBranchOf(dst_id, pred); + // Determine if we are already on a branch where the switch predicate is - // true/false. - if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) - return Status::OK(); + // true/false. Consider both the data and predicate to determine if the + // node is redundant (skipping over identity node). + b = state_map_.FindBranchOf(dst_id, pred); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) { + OutputTensor val; + const Edge* e; + TF_RETURN_IF_ERROR(node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + while (IsIdentity(val.node)) { + TF_RETURN_IF_ERROR(val.node->input_edge(0, &e)); + val = OutputTensor(e->src(), e->src_output()); + } + b = state_map_.FindBranchOf(dst_id, val); + if (b != BranchType::kThenBranch && b != BranchType::kElseBranch) + return Status::OK(); + } - VLOG(5) << "Redundant switch " << node->name(); + VLOG(5) << "Redundant switch " << node->name() << " " << Branch_Name(b) << " " + << DebugString(dst_id); const Edge* value_edge; TF_RETURN_IF_ERROR(node->input_edge(0, &value_edge)); Node* val_node = value_edge->src(); @@ -1171,19 +1119,19 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { graph_->RemoveEdge(e); if (switch_branch == Graph::kControlSlot) { if (IsMerge(dst_node)) { - auto id_or = - JoinCondStatesMerge(dst_id, cond_state_map_.LookupId(dst_node)); + auto id_or = JoinCondStatesMerge(dst_node, dst_id, + state_map_.LookupCondId(dst_node)); TF_RETURN_WITH_CONTEXT_IF_ERROR(id_or.status(), "for node ", FormatNodeForError(*dst_node)); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } else { auto id_or = - JoinCondStatesNonMerge(dst_id, cond_state_map_.LookupId(dst_node)); + JoinCondStatesNonMerge(dst_id, state_map_.LookupCondId(dst_node)); TF_RETURN_IF_ERROR(id_or.status()); - cond_state_map_.ResetId(dst_node, id_or.ValueOrDie()); + state_map_.ResetCondId(dst_node, id_or.ValueOrDie()); } } else if (BranchType(switch_branch) != b) { - cond_state_map_.MarkDead(dst_node); + state_map_.MarkDead(dst_node); delete_nodes_.push_back(dst_node->id()); continue; } @@ -1195,20 +1143,47 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { return Status::OK(); } -Status FunctionalizeCond::DetermineCondStates( - std::vector rev_topo_order) { +Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { // The state that is propagated along the given edge. for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { Node* dst = *it; TF_RETURN_IF_ERROR(DetermineCondState(dst)); + TF_RETURN_IF_ERROR(DetermineAncestorState(dst)); if (IsSwitch(dst)) TF_RETURN_IF_ERROR(RemoveRedundantSwitch(dst)); if (IsMerge(dst)) TF_RETURN_IF_ERROR(RemoveRedundantMerge(dst)); - VLOG(5) << dst->name() << " :: " << cond_state_map_.CondStateToString(dst); + VLOG(5) << dst->name() << " :: " << state_map_.CondStateToString(dst) + << " @ " << state_map_.AncestorStateToString(dst); + if (VLOG_IS_ON(10)) DumpGraphWithCondState("cond_it"); } return Status::OK(); } +Status FunctionalizeCond::DetermineAncestorState(Node* dst) { + StateMap::AncestorId id = nullptr; + StateMap::AncestorState state; + + auto insert = [&](StateMap::AncestorId id, Node* src) { + auto other_id = state_map_.LookupAncestorId(src); + if (other_id != id && other_id != nullptr) { + state.insert(other_id->begin(), other_id->end()); + } + if (IsSwitch(src) || IsMerge(src)) { + state.insert(src); + } + return state_map_.GetAncestorId(state); + }; + + // Compute the union of all the switch/merge nodes that affects the input of + // dst. + for (auto e : dst->in_edges()) { + Node* src = e->src(); + id = insert(id, src); + } + state_map_.ResetAncestorId(dst, id); + return Status::OK(); +} + void FunctionalizeCond::DeleteReachableNodes() { // Delete all nodes that have been extracted or are reachable from // deleted/dead nodes. The input and outgoing edges should have already been @@ -1239,16 +1214,8 @@ void FunctionalizeCond::SortMergeNodes(std::vector* merge_order) { inner_to_outer_merge_order.reserve(merge_order->size()); for (auto it = merge_order->rbegin(); it != merge_order->rend(); ++it) { Node* merge = *it; - CondStateMap::CondId id = cond_state_map_.LookupId(merge); - int depth = 0; - for (auto cond_node_it = id->begin(); cond_node_it != id->end(); - ++cond_node_it) { - if (cond_node_it->type == CondStateMap::CondNode::Type::kSwitch && - (cond_node_it->branch == BranchType::kThenBranch || - cond_node_it->branch == BranchType::kElseBranch)) { - ++depth; - } - } + StateMap::CondId id = state_map_.LookupCondId(merge); + int depth = id != nullptr ? id->size() : 0; inner_to_outer_merge_order.emplace_back(depth, merge); } std::stable_sort( @@ -1271,10 +1238,10 @@ Status FunctionalizeCond::FunctionalizeInternal() { // determine deeper equivalence). We shall refer to this structure as the // CondState; // 3. Sort the merge nodes by nesting depth; - // 4. Extract merge nodes together that have the same CondState and whose - // input nodes have the same state from the innermost to the outermost into - // IfOps; Note: In the above only nodes paths that converge to a merge node - // will be considered for removal. + // 4. Extract merge nodes together that have the same CondState and + // AncestorState from the innermost to the outermost into IfOps; + // Note: In the above only nodes that feed into a merge node will be + // considered for functionalization. // Perform a DFS over the graph and // * Determine the reverse topological order of the nodes (there should be no @@ -1306,40 +1273,40 @@ Status FunctionalizeCond::FunctionalizeInternal() { return Status::OK(); } - TF_RETURN_IF_ERROR(DetermineCondStates(std::move(rev_topo_order))); - + TF_RETURN_IF_ERROR(DetermineStates(std::move(rev_topo_order))); if (VLOG_IS_ON(4)) DumpGraphWithCondState("cond_id"); // Sort the merge nodes from innermost outwards. SortMergeNodes(&merge_order); - // Extract from innermost out. - for (auto it = merge_order.begin(); it != merge_order.end(); ++it) { - Node* merge = *it; - auto id = cond_state_map_.LookupId(merge); - if (cond_state_map_.IsDead(id)) continue; - - // Construct a Conditional with the predicate of the merge (which is the - // last entry of the CondState for the merge) and this as parent. - DCHECK(id->back().predicate.node != nullptr); - Conditional cond(id->back().predicate, this, &cond_state_map_); - TF_RETURN_IF_ERROR(cond.AddMerge(merge)); - - // Find all merge nodes with the same CondId. This is done repeatedly as - // the CondId can change due replaced conditionals. E.g., the one branch - // could previously have had a conditional nested in it, and so would have - // had CondState with sub-state [switch(p,b),m] (where p is some predicate), - // post removing the nested conditional that sub-state would no longer be - // path of the propagated state along that path. - auto end = merge_order.end(); - for (auto merge_candidate_it = std::next(it); merge_candidate_it != end; - ++merge_candidate_it) { - auto merge_candidate_it_id = - cond_state_map_.LookupId(*merge_candidate_it); - if (merge_candidate_it_id != id) continue; - TF_RETURN_IF_ERROR(cond.AddMerge(*merge_candidate_it)); + // Cluster merge nodes by CondId and AncestorId in order of nesting. + using ClusterPair = std::pair; + std::deque> merge_clusters; + std::map merge_cluster_index; + for (Node* merge : merge_order) { + auto cond_id = state_map_.LookupCondId(merge); + if (state_map_.IsDead(cond_id)) continue; + + ClusterPair key = + std::make_pair(cond_id, state_map_.LookupAncestorId(merge)); + auto idx = merge_cluster_index.find(key); + if (idx == merge_cluster_index.end()) { + merge_cluster_index[key] = merge_clusters.size(); + merge_clusters.push_back({merge}); + } else { + merge_clusters[idx->second].emplace_back(merge); } + } + // Extract the conditionals from inner most to outer most. Extracting from + // innermost to outermost enables the extraction pass to stop once it + // encounters a Switch node instead of having to keep track of Switch/Merge + // nodes seen. + for (const auto& cluster : merge_clusters) { + // Construct a Conditional with the predicate of the merge. + Conditional cond(merge_to_predicate_.at(cluster.front()), this, + &state_map_); + for (Node* merge : cluster) TF_RETURN_IF_ERROR(cond.AddMerge(merge)); TF_RETURN_IF_ERROR(cond.BuildAndReplace(graph_, library_)); if (VLOG_IS_ON(4)) DumpGraphWithCondState("after_extract"); @@ -1359,7 +1326,9 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { for (Node* n : graph_->nodes()) { n->ClearAttr(kCondGroupDebugAttr); - n->AddAttr(kCondGroupDebugAttr, cond_state_map_.CondStateToString(n)); + n->AddAttr(kCondGroupDebugAttr, + strings::StrCat(state_map_.CondStateToString(n), "_", + state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " << dump_graph::DumpGraphToFile( diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 86436011c6..28301150ea 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -43,105 +43,88 @@ enum class BranchType { kNeither = 3, }; -// CondStateMap is responsible for mapping from each graph Node to a CondState, -// where each CondState is the array of CondNodes (corresponding to switch, -// merge or dead states) as described below. For efficiency, this class interns -// the CondState, so that CondState equality comparisons are simply pointer +// StateMap is responsible for mapping from each graph Node to +// * a CondState, where each CondState is a map from predicate to branch (i,e., +// what predicates have to hold or not hold). +// * a AncestorState, where each AncestorState is a set of switch/merge nodes +// that are an ancestor of the node in the graph; +// For efficiency, this class interns the CondState (AncestorState), so that +// CondState (AncestorState) equality comparisons are simply pointer // comparisons. -class CondStateMap { +class StateMap { public: - explicit CondStateMap(Graph* graph); - - // Represents an entry in the CondState. An entry can either be the - // switch (along with predicate), merge, or dead: - // * switch node indicates a node that is executed along a branch with the - // given predicate - a branch can be then, else or both; - // * merge node indicates that the node is executed as output of a merge; - // * dead indicates that this node can never be executed; - struct CondNode { - enum class Type { kSwitch = 1, kMerge = 2, kDead = 3 }; - - CondNode(Type type, Node* switch_node = nullptr, - BranchType branch = BranchType::kNeither); - - string ToString() const; - bool operator==(const CondNode& other) const; - bool operator!=(const CondNode& other) const; - - // Type of node. - Type type; - - // Predicate and branch, only used when type is kSwitch. - OutputTensor predicate; - BranchType branch; + explicit StateMap(Graph* graph); + + // Compare two OutputTensors by (node id, index). + struct OutputTensorLess { + bool operator()(const OutputTensor& lhs, const OutputTensor& rhs) const; }; - // A node in the graph is executed when multiple conditions hold. The order - // represents the nesting of the predicates that hold and is used when - // extracting the nested conditionals. - using CondState = std::vector; + // A node in the graph is executed when multiple conditions hold. Keep track + // of the predicates that must hold for a node to execute. + using CondState = std::map; // Every unique ID is mapped to a CondState. using CondId = const CondState*; + // Keep track of which switch/merge node's feed into a node's values. + using AncestorState = std::set; + + // Every unique ID is mapped to a AncestorState. + using AncestorId = const AncestorState*; + // Returns the CondId for a given node. - CondId LookupId(const Node* node) const; + CondId LookupCondId(const Node* node) const; // Returns the unique CondId for CondState. - CondId GetUniqueId(const CondState& state); + CondId GetCondId(const CondState& state); + + // Resets the CondId for a given node. + void ResetCondId(const Node* node, CondId id); + + // Returns the AncestorId for a given node. + AncestorId LookupAncestorId(const Node* node) const; + + // Returns the unique AncestorId for CondState. + AncestorId GetAncestorId(const AncestorState& state); + + // Resets the AncestorId for a given node. + void ResetAncestorId(const Node* node, AncestorId id); // Returns the CondState for a Node. // REQUIRES: node has a non-empty CondState. const CondState& LookupState(const Node* node) const; - // Resets the CondId for a given node. - void ResetId(const Node* node, CondId id); - // Marks `node` as dead. void MarkDead(const Node* node); // Determine branch execution of CondState. BranchType FindBranchOf(CondId id, OutputTensor predicate) const; - // Enum to represent whether one cond flow state contains another. - enum ContainsResult { - kIncomparable, - kEqual, - kLhsContainsRhs, - kRhsContainsLhs - }; - - // Returns whether the lhs CondState holds wherever rhs CondState hols. I.e., - // [(p,t)] contains [(p,t), (r,t)]. - ContainsResult LhsHoldsWhereverRhsHolds(CondId lhs, CondId rhs); - // Returns textual representation of node's CondState. string CondStateToString(const Node* node) const; string CondStateToString(CondId id) const; + // Returns textual representation of node's AncestorState. + string AncestorStateToString(const Node* node) const; + // Returns whether the cond state is the dead state. bool IsDead(CondId id) const; // Returns whether the cond state is the empty state. bool IsEmpty(CondId id) const; - // Computes the predicates that have to hold for a node to execute and returns - // whether it was possible to determine the predicates that must hold. `scope` - // is populated with these predicates. Scope differs from state in that it - // does not include merge and both nodes. - bool ScopeIn(CondId id, CondId* scope); - private: - // Hash for CondNode and CondState. - struct CondHash { - size_t operator()(const CondNode& item) const; - size_t operator()(const CondState& vec) const; + // Hash for CondState and AncestorState. + struct Hash { + size_t operator()(const CondState& map) const; + size_t operator()(const AncestorState& map) const; }; // Set to keep track of unique CondStates. // Pointers to the entries in the unordered set are used as identifiers: // unordered_set guarantees that the pointers remain the same. - std::unordered_set condstate_set_; + std::unordered_set condstate_set_; // Mapping from Node id to CondId. std::vector node_to_condid_map_; @@ -150,7 +133,12 @@ class CondStateMap { // from Node id in the original graph to the CondId, but there will be nodes // added to the original graph (such as If nodes) whose CondState needs to be // tracked too. - std::unordered_map added_node_mapping_; + std::unordered_map added_node_condid_mapping_; + + // AncestorId variants of the CondId members. + std::unordered_set ancestorstate_set_; + std::vector node_to_ancestorid_map_; + std::unordered_map added_node_ancestorid_mapping_; // Identifier of the dead flow state. The empty flow state is represented with // a nullptr. @@ -173,7 +161,8 @@ class FunctionalizeCond { // Add a If node to the graph defined by def that will, amongst other, replace // replacee in the graph. - xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee); + xla::StatusOr AddIfNode(const NodeDef& def, const Node* replacee, + const OutputTensor& predicate); // Propagates the state of a newly inserted node. Status PropagateUpdatedState(const Node* replacee); @@ -185,35 +174,42 @@ class FunctionalizeCond { FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library); // Performs the actual cond functionalization. Iterate over groups of merge - // nodes (linked by common predicate & CondIds of the incomming edges), - // from innermost to outermost, and extract into If nodes. + // nodes (linked by common predicates & ancestor IDs), from innermost to + // outermost, and extract into If nodes. Status FunctionalizeInternal(); // Returns the forward flow state propagated along edge `e`. - // This may modify cond_state_map_. - CondStateMap::CondId StateAlongEdge(const Edge* e); + // This may modify state_map_. + StateMap::CondId StateAlongEdge(const Edge* e); - // Determines the CondState of all the nodes in the given vector where - // the input is expected in reverse topological order. - // This populates the cond_state_map_. - Status DetermineCondStates(std::vector rev_topo_order); + // Determines the CondState and AncestorState of all the nodes in the given + // vector where the input is expected in reverse topological order. + // This populates the state_map_. + Status DetermineStates(std::vector rev_topo_order); // Determine the CondState for a given node using the incomming edges // to the node. Note: it is expected that this node's CondState is only // determined once its input's CondState is. - Status DetermineCondState(Node* dst); + Status DetermineCondState(Node* dst) { + if (IsMerge(dst)) return DetermineCondStateMerge(dst); + return DetermineCondStateNonMerge(dst); + } // Helper functions for DetermineCondState. + Status DetermineCondStateNonMerge(Node* dst); Status DetermineCondStateMerge(Node* dst); - // Helper functions for DetermineCondStates. Determines the dst node's - // CondState by joining the src and dst's CondState where either - // the dst node is a merge or not. - // These may modify cond_state_map_. - xla::StatusOr JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); - xla::StatusOr JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst); + // Determines the dst node's CondState by joining the src and dst's CondState + // where either the dst node is a merge or not. + // These may modify state_map_. + xla::StatusOr JoinCondStatesMerge(Node* merge, + StateMap::CondId src, + StateMap::CondId dst); + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst); + + // Determines which switch/merge nodes are ancestors of this node. + Status DetermineAncestorState(Node* dst); // Checks if a merge node is redundant and if so removes it from the graph. Status RemoveRedundantMerge(Node* node); @@ -228,9 +224,13 @@ class FunctionalizeCond { // Deletes all nodes in/consumers of `delete_nodes_`. void DeleteReachableNodes(); - // Member used to unique the CondState to a unique CondId and keep track of - // CondState/CondId per Node. - CondStateMap cond_state_map_; + // Member used to unique the CondState to a unique CondId (AncestorState to a + // unique AncestorId) and keep track of CondState/CondId + // (AncestorState/AncestorId) per Node. + StateMap state_map_; + + // Mapping from merge nodes to predicate. + std::unordered_map merge_to_predicate_; // Nodes to be deleted. std::deque delete_nodes_; diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index a27f889392..b0aabd63bb 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -37,28 +37,23 @@ class FunctionalizeCondTest : public ::testing::Test { flib_def_.get())); } - CondStateMap::CondId GetUniqueId( - const CondStateMap::CondStateMap::CondState& state) { - return fc_->cond_state_map_.GetUniqueId(state); + StateMap::CondId GetUniqueId(const StateMap::StateMap::CondState& state) { + return fc_->state_map_.GetCondId(state); } - xla::StatusOr JoinCondStatesNonMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesNonMerge(src, dst); - } - - xla::StatusOr JoinCondStatesMerge( - CondStateMap::CondId src, CondStateMap::CondId dst) { - return fc_->JoinCondStatesMerge(src, dst); + string GetString(const StateMap::StateMap::CondId id) { + return fc_->state_map_.CondStateToString(id); } - bool ScopeIn(CondStateMap::CondId ff, CondStateMap::CondId* scope) { - return fc_->cond_state_map_.ScopeIn(ff, scope); + xla::StatusOr JoinCondStatesNonMerge(StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesNonMerge(src, dst); } - CondStateMap::ContainsResult LhsHoldsWhereverRhsHolds( - CondStateMap::CondId lhs, CondStateMap::CondId rhs) { - return fc_->cond_state_map_.LhsHoldsWhereverRhsHolds(lhs, rhs); + xla::StatusOr JoinCondStatesMerge(Node* n, + StateMap::CondId src, + StateMap::CondId dst) { + return fc_->JoinCondStatesMerge(n, src, dst); } FunctionDefLibrary fdef_lib_; @@ -69,50 +64,6 @@ class FunctionalizeCondTest : public ::testing::Test { namespace { -TEST_F(FunctionalizeCondTest, ScopeIn) { - Tensor pred_tensor(DT_BOOL, TensorShape()); - pred_tensor.flat().setZero(); - Node* pred = test::graph::Constant(graph_.get(), pred_tensor, "pred"); - Tensor val_tensor(DT_INT32, TensorShape()); - val_tensor.flat().setZero(); - Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); - - { - CondStateMap::CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope; - ASSERT_TRUE(ScopeIn(id, &scope)); - ASSERT_TRUE(id == scope); - } - - CondStateMap::CondState empty; - { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - CondStateMap::CondId id = GetUniqueId(ss); - CondStateMap::CondId scope_1; - ASSERT_TRUE(ScopeIn(id, &scope_1)); - ASSERT_TRUE(scope_1 == GetUniqueId(empty)); - ASSERT_TRUE(id != scope_1); - - ss.clear(); - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kBoth)); - id = GetUniqueId(ss); - CondStateMap::CondId scope_2; - ASSERT_TRUE(ScopeIn(id, &scope_2)); - - ASSERT_TRUE(LhsHoldsWhereverRhsHolds(scope_1, scope_2) == - CondStateMap::ContainsResult::kLhsContainsRhs); - } -} - TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor pred_tensor(DT_BOOL, TensorShape()); pred_tensor.flat().setZero(); @@ -120,22 +71,18 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { Tensor val_tensor(DT_INT32, TensorShape()); val_tensor.flat().setZero(); Node* val = test::graph::Constant(graph_.get(), val_tensor, "val"); - Node* s = test::graph::Switch(graph_.get(), val, pred); + Node* m = test::graph::Merge(graph_.get(), val, val); - CondStateMap::CondId empty = GetUniqueId({}); - - CondStateMap::CondId then_branch; + StateMap::CondId then_branch; { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kThenBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kThenBranch)); then_branch = GetUniqueId(ss); } - CondStateMap::CondId else_branch; + StateMap::CondId else_branch; { - CondStateMap::CondState ss; - ss.emplace_back(CondStateMap::CondNode( - CondStateMap::CondNode::Type::kSwitch, s, BranchType::kElseBranch)); + StateMap::CondState ss; + ss.insert(std::make_pair(OutputTensor(pred, 0), BranchType::kElseBranch)); else_branch = GetUniqueId(ss); } @@ -144,39 +91,14 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { EXPECT_TRUE(errors::IsInvalidArgument(status)); // Merge between then and else branch. - auto joined_or = JoinCondStatesMerge(then_branch, else_branch); + auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch); TF_EXPECT_OK(joined_or.status()); - CondStateMap::CondId joined = joined_or.ValueOrDie(); + StateMap::CondId joined = joined_or.ValueOrDie(); // Merge between then branch and both branch. auto t = JoinCondStatesNonMerge(then_branch, joined); // Note: this is OK in terms of constraint predication, but TF_EXPECT_OK(t.status()); - - // Post merge the propagated forward flow state has an additional merge. - CondStateMap::CondId post_merge; - { - CondStateMap::CondState ss; - ss = *joined; - ss.emplace_back( - CondStateMap::CondNode(CondStateMap::CondNode::Type::kMerge)); - post_merge = GetUniqueId(ss); - } - - t = JoinCondStatesNonMerge(post_merge, joined); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(joined == t.ValueOrDie()); - - // No predicate that results in two paths predicated on different conditions - // merge. - t = JoinCondStatesMerge(post_merge, joined); - EXPECT_FALSE(t.ok()); - - // Post the merge we are effectively in the root scope and merging should - // result in the more restrictive post merge state. - t = JoinCondStatesNonMerge(post_merge, empty); - TF_EXPECT_OK(t.status()); - EXPECT_TRUE(post_merge == t.ValueOrDie()); } } // namespace -- GitLab From 580a50a4bb30853199de191ba4d98f7390a138db Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 5 Sep 2018 07:34:32 -0700 Subject: [PATCH 116/540] utils cleanup: move the builtins module under operators. PiperOrigin-RevId: 211631516 --- .../autograph/converters/builtin_functions.py | 41 ++-- .../converters/builtin_functions_test.py | 9 +- tensorflow/contrib/autograph/impl/api.py | 4 +- tensorflow/contrib/autograph/operators/BUILD | 11 + .../contrib/autograph/operators/__init__.py | 5 + .../autograph/operators/control_flow.py | 6 +- .../autograph/operators/py_builtins.py | 225 ++++++++++++++++++ .../autograph/operators/py_builtins_test.py | 131 ++++++++++ tensorflow/contrib/autograph/utils/BUILD | 23 +- .../contrib/autograph/utils/__init__.py | 3 - .../contrib/autograph/utils/builtins.py | 143 ----------- .../contrib/autograph/utils/builtins_test.py | 145 ----------- tensorflow/contrib/autograph/utils/tensors.py | 41 ++++ .../contrib/autograph/utils/tensors_test.py | 57 +++++ 14 files changed, 508 insertions(+), 336 deletions(-) create mode 100644 tensorflow/contrib/autograph/operators/py_builtins.py create mode 100644 tensorflow/contrib/autograph/operators/py_builtins_test.py delete mode 100644 tensorflow/contrib/autograph/utils/builtins.py delete mode 100644 tensorflow/contrib/autograph/utils/builtins_test.py create mode 100644 tensorflow/contrib/autograph/utils/tensors.py create mode 100644 tensorflow/contrib/autograph/utils/tensors_test.py diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index b26c52294c..29dce13999 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -21,6 +21,8 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.core import converter +from tensorflow.contrib.autograph.operators import py_builtins +from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates @@ -31,41 +33,32 @@ class BuiltinFunctionTransformer(converter.Base): TF equivalent, like `len`. """ - def _convert_builtin(self, node): + def _convert_builtin(self, f, args, as_expression): template = """ - ag__.utils.dynamic_builtin(func, args) + ag__.func(args) """ - return templates.replace(template, func=node.func, args=node.args)[0].value - - def _convert_print(self, node): - template = """ - ag__.utils.dynamic_print(args) - """ - return templates.replace(template, args=node.args)[0].value + if as_expression: + return templates.replace_as_expression( + template, func=py_builtins.overload_of(f).__name__, args=args) + else: + return templates.replace( + template, func=py_builtins.overload_of(f).__name__, args=args) def visit_Call(self, node): - self.generic_visit(node) - # TODO(mdan): This won't work if the function was hidden. - # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. - if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange', 'float', 'int')): - return self._convert_builtin(node) - # Print needs to be handled separately because it can be read as statement. - if isinstance(node.func, gast.Name) and node.func.id == 'print': - return self._convert_print(node) + node = self.generic_visit(node) + if anno.hasanno(node.func, 'live_val'): + live_val = anno.getanno(node.func, 'live_val') + if live_val in py_builtins.SUPPORTED_BUILTINS: + node = self._convert_builtin(live_val, node.args, as_expression=True) return node def visit_Print(self, node): - self.generic_visit(node) + node = self.generic_visit(node) args = node.values # Following is the case when calling print(a, b) if len(args) == 1 and isinstance(args[0], gast.Tuple): args = args[0].elts - template = """ - fname(args) - """ - function_call = templates.replace(template, fname='print', args=args)[0] - return self.visit(function_call) + return self._convert_builtin(print, args, as_expression=False) def transform(node, ctx): diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py index d0a0cbbeb6..3e3a04f38b 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py @@ -23,6 +23,7 @@ import six from tensorflow.contrib.autograph.converters import builtin_functions from tensorflow.contrib.autograph.core import converter_testing from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -34,11 +35,11 @@ class BuiltinFunctionsTest(converter_testing.TestCase): def test_fn(a): return len(a) - with self.converted(test_fn, builtin_functions, {'len': len}, - array_ops.shape) as result: + with self.converted(test_fn, builtin_functions, {'len': len}) as result: with self.cached_session() as sess: - ops = result.test_fn(constant_op.constant([0, 0, 0])) - self.assertEqual(sess.run(ops), 3) + p = array_ops.placeholder(dtype=dtypes.int32, shape=None) + ops = result.test_fn(p) + self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3) def test_print(self): diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 276a387180..8b38d5d080 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -29,9 +29,9 @@ import six from tensorflow.contrib.autograph.core import config from tensorflow.contrib.autograph.core import converter from tensorflow.contrib.autograph.impl import conversion +from tensorflow.contrib.autograph.operators import py_builtins from tensorflow.contrib.autograph.pyct import compiler from tensorflow.contrib.autograph.pyct import inspect_utils -from tensorflow.contrib.autograph.utils import builtins from tensorflow.contrib.autograph.utils import py_func from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_decorator @@ -150,7 +150,7 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, unknown_arg_value = object() # Sentinel for arguments of unknown value if inspect_utils.isbuiltin(f): - return builtins.dynamic_builtin(f, *args, **kwargs) + return py_builtins.overload_of(f)(*args, **kwargs) if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): # Regular functions diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 332d5dab19..29759bad79 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,6 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", + "py_builtins.py", "slices.py", ], srcs_version = "PY2AND3", @@ -61,6 +62,16 @@ py_test( ], ) +py_test( + name = "py_builtins_test", + srcs = ["py_builtins_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "slices_test", srcs = ["slices_test.py"], diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 392cb60bcc..c4fbc260a2 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -45,6 +45,11 @@ from tensorflow.contrib.autograph.operators.data_structures import list_stack from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.py_builtins import float_ +from tensorflow.contrib.autograph.operators.py_builtins import int_ +from tensorflow.contrib.autograph.operators.py_builtins import len_ +from tensorflow.contrib.autograph.operators.py_builtins import print_ +from tensorflow.contrib.autograph.operators.py_builtins import range_ from tensorflow.contrib.autograph.operators.slices import get_item from tensorflow.contrib.autograph.operators.slices import GetItemOpts from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 9909e52164..9a66a6bb60 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.utils import builtins +from tensorflow.contrib.autograph.operators import py_builtins from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -82,8 +82,8 @@ def _py_for_stmt(iter_, extra_test, body, init_state): def _known_len_for_stmt(iter_, extra_test, body, init_state): - """Overload of for_stmt that iterates over objects that define a length.""" - n = builtins.dynamic_len(iter_) + """Overload of for_stmt that iterates over objects that admit a length.""" + n = py_builtins.len_(iter_) def while_body(iterate_index, *state): iterate = iter_[iterate_index] diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/contrib/autograph/operators/py_builtins.py new file mode 100644 index 0000000000..c5730934e7 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/py_builtins.py @@ -0,0 +1,225 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators corresponding to Python builtin functions. + +List of built-in functions: https://docs.python.org/3/library/functions.html +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.autograph.utils import py_func +from tensorflow.contrib.autograph.utils import tensors +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_parsing_ops +from tensorflow.python.ops import gen_string_ops +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import math_ops + + +UNDEFINED = object() + + +def overload_of(f): + if f in SUPPORTED_BUILTINS: + return BUILTIN_FUINCTIONS_MAP[f.__name__] + return f + + +def abs_(x): + if tensor_util.is_tensor(x): + return _tf_abs(x) + return _py_abs(x) + + +def _tf_abs(x): + return math_ops.abs(x) + + +def _py_abs(x): + return abs(x) + + +def float_(x=0): + if tensor_util.is_tensor(x): + return _tf_float(x) + return _py_float(x) + + +def _tf_float(x): + # TODO(mdan): We shouldn't assume float32. + if x.dtype == dtypes.string: + return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32) + return math_ops.cast(x, dtype=dtypes.float32) + + +def _py_float(x): + return float(x) + + +def int_(x=0, base=UNDEFINED): + if tensor_util.is_tensor(x): + return _tf_int(x, base) + return _py_int(x, base) + + +def _tf_int(x, base): + if base not in (10, UNDEFINED): + raise NotImplementedError('base {} not supported for int'.format(base)) + + # TODO(mdan): We shouldn't assume int32. + if x.dtype == dtypes.string: + return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32) + return math_ops.cast(x, dtype=dtypes.int32) + + +def _py_int(x, base): + if base is UNDEFINED: + return int(x) + return int(x, base) + + +def len_(s): + if tensors.is_tensor_array(s): + return _tf_tensor_array_len(s) + elif tensors.is_tensor_list(s): + return _tf_tensor_list_len(s) + elif tensor_util.is_tensor(s): + return _tf_tensor_len(s) + return _py_len(s) + + +def _tf_tensor_array_len(s): + return s.size() + + +def _tf_tensor_list_len(s): + return list_ops.tensor_list_length(s) + + +def _tf_tensor_len(s): + """Overload of len_ for Tensor arguments.""" + # Statically shaped tensors: length is known ahead of time. + if s.shape.ndims and s.shape[0].value is not None: + return s.shape[0].value + + # Static shape of unknown dimensions: use dynamic shape but statically + # chech that it's a scalar. + shape = array_ops.shape(s) + + assert shape.shape, 'shape tensor of zero size? {}'.format(shape) + + if shape.shape[0] == 0: + raise ValueError( + 'len requires a non-scalar tensor, got one of shape {}'.format(shape)) + + if shape.shape[0].value is not None: + return array_ops.shape(s)[0] + + # Fully dynamic shape: use ops. + rank = array_ops.rank(s) + + def raise_zero_rank_error(): + msg = gen_string_ops.string_join( + ['len requires non-zero rank, got ', + gen_string_ops.as_string(rank)]) + with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]): + return constant_op.constant(0, dtype=dtypes.int32) + + return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0], + raise_zero_rank_error) + + +def _py_len(s): + return len(s) + + +def print_(*objects, **kwargs): + # Note: Python 2.6 doesn't support explicit keywords after starargs. + unknown_kwargs = tuple( + set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush'))) + if unknown_kwargs: + raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs)) + + # TODO(mdan): use logging_ops.Print when py_func is not supported. + return _tf_py_func_print(objects, kwargs) + + +def _tf_py_func_print(objects, kwargs): + """Overload of print_ as a py_func implementation.""" + override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED} + if 'flush' not in override_kwargs: + # Defaulting to flushing the console in graph mode, which helps reduce + # garbled output in IPython. + override_kwargs['flush'] = True + + def print_wrapper(*vals): + if six.PY3: + # TensorFlow doesn't seem to generate Unicode when passing strings to + # py_func. This causes the print to add a "b'" wrapper to the output, + # which is probably never what you want. + vals = tuple( + v.decode('utf-8') if isinstance(v, bytes) else v for v in vals) + six.print_(*vals, **override_kwargs) + + return py_func.wrap_py_func( + print_wrapper, None, objects, use_dummy_return=True) + + +def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED): + if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)): + return _tf_range(start_or_stop, stop, step) + return _py_range(start_or_stop, stop, step) + + +def _tf_range(start_or_stop, stop, step): + # TODO(mdan): We should optimize this when a full tensor is not required. + if step is not UNDEFINED: + return math_ops.range(start_or_stop, stop, step) + if stop is not UNDEFINED: + return math_ops.range(start_or_stop, stop) + return math_ops.range(start_or_stop) + + +def _py_range(start_or_stop, stop, step): + if step is not UNDEFINED: + return range(start_or_stop, stop, step) + if stop is not UNDEFINED: + return range(start_or_stop, stop) + return range(start_or_stop) + + +SUPPORTED_BUILTINS = set((abs, float, int, len, print, range)) + +if six.PY2: + SUPPORTED_BUILTINS.add(xrange) + +BUILTIN_FUINCTIONS_MAP = { + 'abs': abs_, + 'float': float_, + 'int': int_, + 'len': len_, + 'print': print_, + 'range': range_, + 'xrange': range_, +} diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/contrib/autograph/operators/py_builtins_test.py new file mode 100644 index 0000000000..4073c51785 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/py_builtins_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for py_builtins module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.contrib.autograph.operators import py_builtins +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class PyBuiltinsTest(test.TestCase): + + def test_abs(self): + self.assertEqual(py_builtins.abs_(-1), 1) + with self.test_session() as sess: + t = py_builtins.abs_(constant_op.constant(-1)) + self.assertEqual(sess.run(t), 1) + t = py_builtins.abs_(constant_op.constant([-1, 2, -3])) + self.assertAllEqual(sess.run(t), [1, 2, 3]) + + def test_float(self): + self.assertEqual(py_builtins.float_(10), 10.0) + self.assertEqual(py_builtins.float_('10.0'), 10.0) + with self.test_session() as sess: + t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64)) + self.assertEqual(sess.run(t), 1.0) + st = py_builtins.float_(constant_op.constant('1.0')) + self.assertEqual(sess.run(st), 1.0) + + def test_int(self): + self.assertEqual(py_builtins.int_(10.0), 10) + self.assertEqual(py_builtins.int_('11', 2), 3) + with self.test_session() as sess: + t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64)) + self.assertEqual(sess.run(t), 1) + st = py_builtins.int_(constant_op.constant('1')) + self.assertEqual(sess.run(st), 1) + st = py_builtins.int_(constant_op.constant('1'), 10) + self.assertEqual(sess.run(st), 1) + + def test_int_unsupported_base(self): + t = constant_op.constant(1, dtype=dtypes.float64) + with self.assertRaises(NotImplementedError): + py_builtins.int_(t, 2) + + def test_len(self): + self.assertEqual(py_builtins.len_([1, 2, 3]), 3) + with self.test_session() as sess: + t = py_builtins.len_(constant_op.constant([[1], [2], [3]])) + self.assertEqual(t, 3) + ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5)) + self.assertEqual(sess.run(ta), 5) + tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5])) + self.assertEqual(sess.run(tl), 3) + + def test_len_scalar(self): + with self.assertRaises(ValueError): + py_builtins.len_(constant_op.constant(1)) + + def test_len_dynamic_shape(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtype=dtypes.int32, shape=None) + t = py_builtins.len_(p) + self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3) + + with self.assertRaises(errors_impl.InvalidArgumentError): + t = py_builtins.len_(p) + sess.run(t, {p: 1}) + + def test_print_tensors(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(py_builtins.print_(constant_op.constant('test message'), 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_print_complex(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run( + py_builtins.print_(constant_op.constant('test message'), [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + def test_range(self): + self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2]) + self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2]) + self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1]) + + def test_range_tensor(self): + with self.test_session() as sess: + r = py_builtins.range_(constant_op.constant(3)) + self.assertAllEqual(sess.run(r), [0, 1, 2]) + r = py_builtins.range_(1, constant_op.constant(3)) + self.assertAllEqual(sess.run(r), [1, 2]) + r = py_builtins.range_(2, 0, constant_op.constant(-1)) + self.assertAllEqual(sess.run(r), [2, 1]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d2b399f19b..4504a5c7a3 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -20,12 +20,12 @@ py_library( name = "utils", srcs = [ "__init__.py", - "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", "py_func.py", "tensor_list.py", + "tensors.py", "testing.py", "type_check.py", ], @@ -41,17 +41,6 @@ py_library( ], ) -py_test( - name = "builtins_test", - srcs = ["builtins_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ - ":utils", - "//tensorflow/python:client_testlib", - ], -) - py_test( name = "context_managers_test", srcs = ["context_managers_test.py"], @@ -113,3 +102,13 @@ py_test( "//tensorflow/python:list_ops", ], ) + +py_test( + name = "tensors_test", + srcs = ["tensors_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py index 57b5f74741..38e0a0a8f0 100644 --- a/tensorflow/contrib/autograph/utils/__init__.py +++ b/tensorflow/contrib/autograph/utils/__init__.py @@ -18,9 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin -from tensorflow.contrib.autograph.utils.builtins import dynamic_print -from tensorflow.contrib.autograph.utils.builtins import dynamic_range from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns from tensorflow.contrib.autograph.utils.misc import alias_tensors from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py deleted file mode 100644 index 4dd440ef19..0000000000 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Builtin conversion utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import py_func -from tensorflow.contrib.autograph.utils import type_check -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import list_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops - - -def dynamic_builtin(f, *args, **kwargs): - """Converts a builtin function call inline.""" - if f is len: - return dynamic_len(*args, **kwargs) - if six.PY2 and f is xrange: - return dynamic_range(*args, **kwargs) - if f is range: - return dynamic_range(*args, **kwargs) - if f is int: - return dynamic_int(*args, **kwargs) - if f is float: - return dynamic_float(*args, **kwargs) - if f is abs: - return dynamic_abs(*args, **kwargs) - - raise NotImplementedError( - 'The "%s" builtin is not yet supported.' % f.__name__) - - -def dynamic_len(list_or_tensor): - """Implementation of len using dynamic dispatch.""" - if _is_tensor_list(list_or_tensor): - return list_ops.tensor_list_length(list_or_tensor) - elif tensor_util.is_tensor(list_or_tensor): - shape = list_or_tensor.shape - if not shape.ndims: - raise ValueError( - 'len requires non-zero rank for tensor "%s"' % list_or_tensor) - return array_ops.shape(list_or_tensor)[0] - return len(list_or_tensor) - - -def _is_tensor_list(list_or_tensor): - return (tensor_util.is_tensor(list_or_tensor) - and list_or_tensor.dtype == dtypes.variant) - - -def dynamic_int(num_or_tensor, **kwargs): - """Implementation of int() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) - return int(num_or_tensor) - - -def dynamic_float(num_or_tensor, **kwargs): - """Implementation of float() using dynamic dispatch.""" - if tensor_util.is_tensor(num_or_tensor): - return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) - return float(num_or_tensor) - - -def dynamic_abs(num_or_tensor, **kwargs): - if tensor_util.is_tensor(num_or_tensor): - return math_ops.abs(num_or_tensor, **kwargs) - else: - return abs(num_or_tensor, **kwargs) - - -def dynamic_range(start_or_stop, stop=None, step=None): - """Implementation of range using dynamic dispatch.""" - if type_check.is_tensor(start_or_stop, stop, step): - if step is not None: - return math_ops.range(start_or_stop, stop, step) - if stop is not None: - return math_ops.range(start_or_stop, stop) - return math_ops.range(start_or_stop) - - if step is not None: - return range(start_or_stop, stop, step) - elif stop is not None: - return range(start_or_stop, stop) - return range(start_or_stop) - - -def is_tf_print_compatible(value): - # TODO(mdan): Enable once we can reliably test this. - # This is currently disabled because we can't capture the output of - # op kernels from Python. - del value - return False - - -def dynamic_print(*values): - """Implementation of print using dynamic dispatch. - - The function attempts to use tf.Print if all the values are compatible. - Otherwise, it will fall back to py_func. - - Args: - *values: values to print - Returns: - A dummy value indicating the print completed. If tf. - """ - - if all(map(is_tf_print_compatible, values)): - return logging_ops.Print(1, values) - - def print_wrapper(*vals): - if six.PY3: - # TensorFlow doesn't seem to generate Unicode when passing strings to - # py_func. This causes the print to add a "b'" wrapper to the output, - # which is probably never what you want. - vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals) - print(*vals) - # The flush helps avoid garbled output in IPython. - sys.stdout.flush() - - return py_func.wrap_py_func( - print_wrapper, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py deleted file mode 100644 index b1cd5253bc..0000000000 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for builtins module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - -import six - -from tensorflow.contrib.autograph.utils import builtins -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.platform import test - - -class BuiltinsTest(test.TestCase): - - def test_dynamic_len_tf_scalar(self): - a = constant_op.constant(1) - - with self.assertRaisesRegexp(ValueError, - 'len requires non-zero rank for tensor.*'): - with self.test_session() as sess: - sess.run(builtins.dynamic_builtin(len, a)) - - def test_dynamic_len_tf_array(self): - a = constant_op.constant([1, 2, 3]) - - with self.test_session() as sess: - self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_abs_tf_scalar(self): - a = constant_op.constant(-1) - - with self.test_session() as sess: - self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a))) - - def test_dynamic_abs_tf_array(self): - a = constant_op.constant([-1, 2, -3]) - - with self.test_session() as sess: - self.assertListEqual([1, 2, 3], - list(sess.run(builtins.dynamic_builtin(abs, a)))) - - def test_dynamic_abs_py_scalar(self): - a = -1 - self.assertEqual(1, builtins.dynamic_builtin(abs, a)) - - def test_dynamic_len_tf_matrix(self): - a = constant_op.constant([[1, 2], [3, 4]]) - - with self.test_session() as sess: - self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) - - def test_dynamic_len_py_list(self): - a = [3] * 5 - - self.assertEqual(5, builtins.dynamic_builtin(len, a)) - - def test_dynamic_range_all_python(self): - self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2]) - self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1]) - - def test_dynamic_range_tf(self): - with self.test_session() as sess: - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))), - [0, 1, 2]) - self.assertAllEqual( - sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))), - [1, 2]) - self.assertAllEqual( - sess.run( - builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))), - [2, 1]) - - def test_dynamic_range_detection(self): - def range(x): # pylint:disable=redefined-builtin - return x - - # Functions that just have the names of builtins are rejected. - with self.assertRaises(NotImplementedError): - self.assertEqual(builtins.dynamic_builtin(range, 1), 1) - if six.PY2: - self.assertListEqual( - list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2]) - self.assertListEqual( - list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) - - def test_casts(self): - i = constant_op.constant(2, dtype=dtypes.int32) - f = constant_op.constant(1.0, dtype=dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) - self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) - self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) - - self.assertEqual(builtins.dynamic_builtin(int, True), 1) - self.assertEqual(builtins.dynamic_builtin(int, False), 0) - self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) - self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) - - def test_dynamic_print_tf(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', 1)) - self.assertEqual(out_capturer.getvalue(), 'test message 1\n') - finally: - sys.stdout = sys.__stdout__ - - def test_dynamic_print_complex(self): - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - with self.test_session() as sess: - sess.run(builtins.dynamic_print('test message', [1, 2])) - self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') - finally: - sys.stdout = sys.__stdout__ - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/contrib/autograph/utils/tensors.py new file mode 100644 index 0000000000..fa5db81a71 --- /dev/null +++ b/tensorflow/contrib/autograph/utils/tensors.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""This module defines tensor utilities not found in TensorFlow. + +The reason these utilities are not defined in TensorFlow is because they may +not be not fully robust, although they work in the vast majority of cases. So +we define them here in order for their behavior to be consistently verified. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import tensor_array_ops + + +def is_tensor_array(t): + return isinstance(t, tensor_array_ops.TensorArray) + + +def is_tensor_list(t): + # TODO(mdan): This is just a heuristic. + # With TF lacking support for templated types, this is unfortunately the + # closest we can get right now. A dedicated op ought to be possible to + # construct. + return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and + not t.shape.ndims) diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/contrib/autograph/utils/tensors_test.py new file mode 100644 index 0000000000..e855e0b6cb --- /dev/null +++ b/tensorflow/contrib/autograph/utils/tensors_test.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensors module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import tensors +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.platform import test + + +class TensorsTest(test.TestCase): + + def _simple_tensor_array(self): + return tensor_array_ops.TensorArray(dtypes.int32, size=3) + + def _simple_tensor_list(self): + return list_ops.empty_tensor_list( + element_shape=constant_op.constant([1]), element_dtype=dtypes.int32) + + def _simple_list_of_tensors(self): + return [constant_op.constant(1), constant_op.constant(2)] + + def test_is_tensor_array(self): + self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array())) + self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list())) + self.assertFalse(tensors.is_tensor_array(constant_op.constant(1))) + self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors())) + self.assertFalse(tensors.is_tensor_array(None)) + + def test_is_tensor_list(self): + self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array())) + self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list())) + self.assertFalse(tensors.is_tensor_list(constant_op.constant(1))) + self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors())) + self.assertFalse(tensors.is_tensor_list(None)) + + +if __name__ == '__main__': + test.main() -- GitLab From 1f96f9d350726b06a9f44aebcb4c1df54693894a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 07:56:52 -0700 Subject: [PATCH 117/540] Convert more kernel signatures to use runtime shapes. PiperOrigin-RevId: 211633744 --- .../internal/optimized/optimized_ops.h | 12 + .../internal/reference/reference_ops.h | 397 +++++++++++++----- .../contrib/lite/kernels/internal/types.h | 6 +- 3 files changed, 309 insertions(+), 106 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index 9b35648b4e..2c8e8f90e3 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -43,6 +43,14 @@ namespace optimized_ops { // Unoptimized reference ops: using reference_ops::ArgMax; using reference_ops::ArgMinMax; +using reference_ops::Broadcast4DSlowGreater; +using reference_ops::Broadcast4DSlowGreaterEqual; +using reference_ops::Broadcast4DSlowGreaterEqualWithScaling; +using reference_ops::Broadcast4DSlowGreaterWithScaling; +using reference_ops::Broadcast4DSlowLess; +using reference_ops::Broadcast4DSlowLessEqual; +using reference_ops::Broadcast4DSlowLessEqualWithScaling; +using reference_ops::Broadcast4DSlowLessWithScaling; using reference_ops::BroadcastAdd4DSlow; using reference_ops::BroadcastGreater; using reference_ops::BroadcastGreaterEqual; @@ -58,8 +66,12 @@ using reference_ops::FakeQuant; using reference_ops::Gather; using reference_ops::Greater; using reference_ops::GreaterEqual; +using reference_ops::GreaterEqualWithScaling; +using reference_ops::GreaterWithScaling; using reference_ops::Less; using reference_ops::LessEqual; +using reference_ops::LessEqualWithScaling; +using reference_ops::LessWithScaling; using reference_ops::Mean; using reference_ops::RankOneSelect; using reference_ops::Relu1; diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index e5b71f81fa..00f9616cc2 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3452,23 +3452,55 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data, } template -inline void Gather(const T* input_data, const Dims<4>& input_dims, - int input_rank, const int32* coords_data, - const Dims<4>& coords_dims, T* output_data, - const Dims<4>& output_dims) { - TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); - int stride = input_dims.strides[input_rank - 1]; +inline void Gather(const tflite::GatherParams& op_params, + const RuntimeShape& input_shape, const T* input_data, + const RuntimeShape& coords_shape, const int32* coords_data, + const RuntimeShape& output_shape, T* output_data) { + // TODO(b/80418076): Enable these checks when moving legacy ops to + // legacy_reference_ops. + // + // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1); + const int input_rank = op_params.input_rank; + const int gather_dimensions = output_shape.DimensionsCount(); + TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions); + const int axis = gather_dimensions - input_rank; + TFLITE_DCHECK_LT(axis, gather_dimensions); + TFLITE_DCHECK_GE(axis, 0); + const int coords_count = coords_shape.FlatSize(); + TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis)); + + int64_t stride = 1; + for (int i = axis + 1; i < gather_dimensions; ++i) { + stride *= input_shape.Dims(i); + } T* out = output_data; - for (int i = 0; i < coords_dims.sizes[0]; i++) { + for (int i = 0; i < coords_count; ++i) { TFLITE_DCHECK_GE(coords_data[i], 0); - TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis)); const T* in = input_data + coords_data[i] * stride; memcpy(out, in, sizeof(T) * stride); out += stride; } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4> version. +// When moving legacy ops to legacy_reference_ops, replace content with looser +// implementation. +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + tflite::GatherParams op_params; + op_params.input_rank = input_rank; + + Gather(op_params, DimsToShape(input_dims), input_data, + DimsToShape(coords_dims), coords_data, DimsToShape(output_dims), + output_data); +} + template inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape& unextended_input_shape, @@ -4337,9 +4369,10 @@ template using ComparisonFn = bool (*)(T, T); template F> -inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data, - const RuntimeShape& input2_shape, const T* input2_data, - const RuntimeShape& output_shape, bool* output_data) { +inline void ComparisonImpl( + const ComparisonParams& op_params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { const int64_t flatsize = MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int64_t i = 0; i < flatsize; ++i) { @@ -4347,25 +4380,45 @@ inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data, } } +template F> +inline void Comparison(const ComparisonParams& op_params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, bool* output_data) { + ComparisonImpl(op_params, input1_shape, input1_data, input2_shape, + input2_data, output_shape, output_data); +} + +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. template F> inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, bool* output_data, const Dims<4>& output_dims) { - Comparison(DimsToShape(input1_dims), input1_data, - DimsToShape(input2_dims), input2_data, - DimsToShape(output_dims), output_data); + ComparisonParams op_params; + // No parameters needed. + ComparisonImpl(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); } template F> -inline void Comparison(int left_shift, const T* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const T* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, bool* output_data, - const Dims<4>& output_dims) { +inline void ComparisonWithScaling( + const ComparisonParams& op_params, const RuntimeShape& input1_shape, + const T* input1_data, const RuntimeShape& input2_shape, + const T* input2_data, const RuntimeShape& output_shape, bool* output_data) { + int left_shift = op_params.left_shift; + int32 input1_offset = op_params.input1_offset; + int32 input1_multiplier = op_params.input1_multiplier; + int input1_shift = op_params.input1_shift; + int32 input2_offset = op_params.input2_offset; + int32 input2_multiplier = op_params.input2_multiplier; + int input2_shift = op_params.input2_shift; + const int64_t flatsize = - MatchingFlatSize(input1_dims, input2_dims, output_dims); + MatchingFlatSize(input1_shape, input2_shape, output_shape); for (int64_t i = 0; i < flatsize; ++i) { const int32 input1_val = input1_offset + input1_data[i]; const int32 input2_val = input2_offset + input2_data[i]; @@ -4373,68 +4426,140 @@ inline void Comparison(int left_shift, const T* input1_data, const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, input1_multiplier, input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); + shifted_input2_val, input2_multiplier, input2_shift); output_data[i] = F(scaled_input1_val, scaled_input2_val); } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template F> +inline void Comparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, bool* output_data, + const Dims<4>& output_dims) { + tflite::ComparisonParams op_params; + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + + ComparisonWithScaling(op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, + DimsToShape(output_dims), output_data); +} + template F> -inline void BroadcastComparison(const T* input1_data, - const Dims<4>& input1_dims, - const T* input2_data, - const Dims<4>& input2_dims, bool* output_data, - const Dims<4>& output_dims) { +inline void BroadcastComparison4DSlowImpl( + const ComparisonParams& op_params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const T* input2_data, + const RuntimeShape& unextended_output_shape, bool* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow"); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = - F(input1_data[SubscriptToIndex(desc1, c, x, y, b)], - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); + + for (int b = 0; b < output_shape.Dims(0); ++b) { + for (int y = 0; y < output_shape.Dims(1); ++y) { + for (int x = 0; x < output_shape.Dims(2); ++x) { + for (int c = 0; c < output_shape.Dims(3); ++c) { + output_data[Offset(output_shape, b, y, x, c)] = + F(input1_data[SubscriptToIndex(desc1, b, y, x, c)], + input2_data[SubscriptToIndex(desc2, b, y, x, c)]); } } } } } +template F> +inline void BroadcastComparison4DSlow(const ComparisonParams& op_params, + const RuntimeShape& input1_shape, + const float* input1_data, + const RuntimeShape& input2_shape, + const float* input2_data, + const RuntimeShape& output_shape, + bool* output_data) { + BroadcastComparison4DSlowImpl(op_params, input1_shape, input1_data, + input2_shape, input2_data, + output_shape, output_data); +} -template F> -inline void BroadcastComparison(int left_shift, const T* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template F> +inline void BroadcastComparison(const T* input1_data, + const Dims<4>& input1_dims, const T* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 input2_multiplier, int input2_shift, - bool* output_data, const Dims<4>& output_dims) { + const Dims<4>& input2_dims, bool* output_data, + const Dims<4>& output_dims) { + ComparisonParams op_params; + // No parameters needed. + BroadcastComparison4DSlowImpl(op_params, DimsToShape(input1_dims), + input1_data, DimsToShape(input2_dims), + input2_data, DimsToShape(output_dims), + output_data); +} + +template F> +inline void BroadcastComparison4DSlowWithScaling( + const ComparisonParams& op_params, + const RuntimeShape& unextended_input1_shape, const T* input1_data, + const RuntimeShape& unextended_input2_shape, const T* input2_data, + const RuntimeShape& unextended_output_shape, bool* output_data) { + gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling"); + TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + NdArrayDesc<4> desc1; NdArrayDesc<4> desc2; - NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); - for (int b = 0; b < ArraySize(output_dims, 3); ++b) { - for (int y = 0; y < ArraySize(output_dims, 2); ++y) { - for (int x = 0; x < ArraySize(output_dims, 1); ++x) { - for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, + unextended_input2_shape, &desc1, &desc2); + + int left_shift = op_params.left_shift; + int32 input1_offset = op_params.input1_offset; + int32 input1_multiplier = op_params.input1_multiplier; + int input1_shift = op_params.input1_shift; + int32 input2_offset = op_params.input2_offset; + int32 input2_multiplier = op_params.input2_multiplier; + int input2_shift = op_params.input2_shift; + + for (int b = 0; b < output_shape.Dims(0); ++b) { + for (int y = 0; y < output_shape.Dims(1); ++y) { + for (int x = 0; x < output_shape.Dims(2); ++x) { + for (int c = 0; c < output_shape.Dims(3); ++c) { const int32 input1_val = - input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)]; const int32 input2_val = - input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input1_val, input1_multiplier, - kReverseShift * input1_shift); + shifted_input1_val, input1_multiplier, input1_shift); const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOneExp( - shifted_input2_val, input2_multiplier, - kReverseShift * input2_shift); - output_data[Offset(output_dims, c, x, y, b)] = + shifted_input2_val, input2_multiplier, input2_shift); + output_data[Offset(output_shape, b, y, x, c)] = F(scaled_input1_val, scaled_input2_val); } } @@ -4442,51 +4567,117 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, } } -#define TFLITE_COMPARISON_OP(name) \ - template \ - inline void name(const T* input1_data, const Dims<4>& input1_dims, \ - const T* input2_data, const Dims<4>& input2_dims, \ - bool* output_data, const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label(#name); \ - Comparison(input1_data, input1_dims, input2_data, \ - input2_dims, output_data, output_dims); \ - } \ - template \ - inline void name( \ - int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ - int32 input1_offset, int32 input1_multiplier, int input1_shift, \ - const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ - int32 input2_multiplier, int input2_shift, bool* output_data, \ - const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ - Comparison(left_shift, input1_data, input1_dims, \ - input1_offset, input1_multiplier, input1_shift, \ - input2_data, input2_dims, input2_offset, \ - input2_multiplier, input2_shift, output_data, \ - output_dims); \ - } \ - template \ - inline void Broadcast##name( \ - const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \ - const Dims<4>& input2_dims, bool* output_data, \ - const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ - BroadcastComparison(input1_data, input1_dims, input2_data, \ - input2_dims, output_data, output_dims); \ - } \ - template \ - inline void Broadcast##name( \ - int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ - int32 input1_offset, int32 input1_multiplier, int input1_shift, \ - const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ - int32 input2_multiplier, int input2_shift, bool* output_data, \ - const Dims<4>& output_dims) { \ - gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ - BroadcastComparison(left_shift, input1_data, input1_dims, \ - input1_offset, input1_multiplier, \ - input1_shift, input2_data, input2_dims, \ - input2_offset, input2_multiplier, \ - input2_shift, output_data, output_dims); \ +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy. +template F> +inline void BroadcastComparison(int left_shift, const T* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const T* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 input2_multiplier, int input2_shift, + bool* output_data, const Dims<4>& output_dims) { + ComparisonParams op_params; + + op_params.left_shift = left_shift; + op_params.input1_offset = input1_offset; + op_params.input1_multiplier = input1_multiplier; + op_params.input1_shift = kReverseShift * input1_shift; + op_params.input2_offset = input2_offset; + op_params.input2_multiplier = input2_multiplier; + op_params.input2_shift = kReverseShift * input2_shift; + + BroadcastComparison4DSlowWithScaling( + op_params, DimsToShape(input1_dims), input1_data, + DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), + output_data); +} + +#define TFLITE_COMPARISON_OP(name) \ + template \ + inline void name(const T* input1_data, const Dims<4>& input1_dims, \ + const T* input2_data, const Dims<4>& input2_dims, \ + bool* output_data, const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name); \ + Comparison(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template \ + inline void name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ + Comparison(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, input1_shift, \ + input2_data, input2_dims, input2_offset, \ + input2_multiplier, input2_shift, output_data, \ + output_dims); \ + } \ + template \ + inline void Broadcast##name( \ + const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \ + const Dims<4>& input2_dims, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ + BroadcastComparison(input1_data, input1_dims, input2_data, \ + input2_dims, output_data, output_dims); \ + } \ + template \ + inline void Broadcast##name( \ + int left_shift, const T* input1_data, const Dims<4>& input1_dims, \ + int32 input1_offset, int32 input1_multiplier, int input1_shift, \ + const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \ + int32 input2_multiplier, int input2_shift, bool* output_data, \ + const Dims<4>& output_dims) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ + BroadcastComparison(left_shift, input1_data, input1_dims, \ + input1_offset, input1_multiplier, \ + input1_shift, input2_data, input2_dims, \ + input2_offset, input2_multiplier, \ + input2_shift, output_data, output_dims); \ + } \ + inline void name(const ComparisonParams& op_params, \ + const RuntimeShape& input1_shape, const float* input1_data, \ + const RuntimeShape& input2_shape, const float* input2_data, \ + const RuntimeShape& output_shape, bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label(#name); \ + Comparison(op_params, input1_shape, input1_data, input2_shape, \ + input2_data, output_shape, output_data); \ + } \ + template \ + inline void name##WithScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \ + ComparisonWithScaling(op_params, input1_shape, input1_data, \ + input2_shape, input2_data, \ + output_shape, output_data); \ + } \ + inline void Broadcast4DSlow##name( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const float* input1_data, const RuntimeShape& input2_shape, \ + const float* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \ + BroadcastComparison4DSlow(op_params, input1_shape, input1_data, \ + input2_shape, input2_data, \ + output_shape, output_data); \ + } \ + template \ + inline void Broadcast4DSlow##name##WithScaling( \ + const ComparisonParams& op_params, const RuntimeShape& input1_shape, \ + const T* input1_data, const RuntimeShape& input2_shape, \ + const T* input2_data, const RuntimeShape& output_shape, \ + bool* output_data) { \ + gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \ + BroadcastComparison4DSlowWithScaling( \ + op_params, input1_shape, input1_data, input2_shape, input2_data, \ + output_shape, output_data); \ } TFLITE_COMPARISON_OP(Equal); TFLITE_COMPARISON_OP(NotEqual); diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 6ae4ebc79e..9f6e74a267 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -720,12 +720,12 @@ struct ConcatenationParams { struct ComparisonParams { // uint8 inference params. int left_shift; - int32 input0_offset; - int32 input0_multiplier; - int input0_shift; int32 input1_offset; int32 input1_multiplier; int input1_shift; + int32 input2_offset; + int32 input2_multiplier; + int input2_shift; // Shape dependent / common to inference types. bool is_broadcast; }; -- GitLab From cb520088ac02b25e7ccc720ca7fbb01692d2a0c2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 08:23:07 -0700 Subject: [PATCH 118/540] Exclude icf=all from TFLite linker options on iOS. PiperOrigin-RevId: 211637019 --- tensorflow/contrib/lite/build_def.bzl | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index fc199f0a0e..0246e7fa30 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -57,6 +57,7 @@ def tflite_linkopts_unstripped(): "-Wl,--as-needed", # Don't link unused libs. ], "//tensorflow:darwin": [], + "//tensorflow:ios": [], "//tensorflow/contrib/lite:mips": [], "//tensorflow/contrib/lite:mips64": [], "//conditions:default": [ -- GitLab From cdf986398f9c92b636a0c8a973e4cccb3749d9ef Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 5 Sep 2018 08:42:48 -0700 Subject: [PATCH 119/540] Alias tensorflow::gtl::InlinedVector to absl::InlinedVector PiperOrigin-RevId: 211639440 --- tensorflow/core/BUILD | 2 +- .../core/common_runtime/pool_allocator.cc | 1 + tensorflow/core/lib/gtl/inlined_vector.h | 665 +------------ .../core/lib/gtl/inlined_vector_test.cc | 898 ------------------ .../core/platform/default/build_config.bzl | 1 + tensorflow/stream_executor/blas.h | 1 + 6 files changed, 9 insertions(+), 1559 deletions(-) delete mode 100644 tensorflow/core/lib/gtl/inlined_vector_test.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5c314f359c..c06fea130f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -695,6 +695,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":lib_internal", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], @@ -3220,7 +3221,6 @@ tf_cc_tests( "lib/gtl/edit_distance_test.cc", "lib/gtl/flatmap_test.cc", "lib/gtl/flatset_test.cc", - "lib/gtl/inlined_vector_test.cc", "lib/gtl/int_type_test.cc", "lib/gtl/iterator_range_test.cc", "lib/gtl/manual_constructor_test.cc", diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc index 10a24ed14c..fdad8de8d6 100644 --- a/tensorflow/core/common_runtime/pool_allocator.cc +++ b/tensorflow/core/common_runtime/pool_allocator.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index c18dc9ad1a..2d622dc229 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -13,674 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// An InlinedVector is like a std::vector, except that storage -// for sequences of length <= N are provided inline without requiring -// any heap allocation. Typically N is very small (e.g., 4) so that -// sequences that are expected to be short do not require allocations. -// -// Only some of the std::vector<> operations are currently implemented. -// Other operations may be added as needed to facilitate migrating -// code that uses std::vector<> to InlinedVector<>. -// -// NOTE: If you want an inlined version to replace use of a -// std::vector, consider using util::bitmap::InlinedBitVector -// in util/bitmap/inlined_bitvector.h -// -// TODO(billydonahue): change size_t to size_type where appropriate. - #ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/core/lib/gtl/manual_constructor.h" -#include "tensorflow/core/platform/byte_order.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" +#include "absl/container/inlined_vector.h" +// TODO(kramerb): This is kept only because lots of targets transitively depend +// on it. Remove all targets' dependencies. +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include // NOLINT(build/include_order) - namespace tensorflow { namespace gtl { -template -class InlinedVector { - public: - typedef T value_type; - typedef T* pointer; - typedef const T* const_pointer; - typedef T& reference; - typedef const T& const_reference; - typedef size_t size_type; - typedef std::ptrdiff_t difference_type; - typedef pointer iterator; - typedef const_pointer const_iterator; - - // Create an empty vector - InlinedVector(); - - // Create a vector with n copies of value_type(). - explicit InlinedVector(size_t n); - - // Create a vector with n copies of elem - InlinedVector(size_t n, const value_type& elem); - - // Create and initialize with the elements [range_start .. range_end). - // The unused enable_if argument restricts this constructor so that it is - // elided when value_type is an integral type. This prevents ambiguous - // interpretation between a call to this constructor with two integral - // arguments and a call to the preceding (n, elem) constructor. - template - InlinedVector( - InputIterator range_start, InputIterator range_end, - typename std::enable_if::value>::type* = - NULL) { - InitRep(); - AppendRange(range_start, range_end); - } - - InlinedVector(std::initializer_list init) { - InitRep(); - AppendRange(init.begin(), init.end()); - } - - InlinedVector(const InlinedVector& v); - - ~InlinedVector() { clear(); } - - InlinedVector& operator=(const InlinedVector& v) { - // Optimized to avoid reallocation. - // Prefer reassignment to copy construction for elements. - const size_t s = size(); - const size_t vs = v.size(); - if (s < vs) { // grow - reserve(vs); - if (s) std::copy(v.begin(), v.begin() + s, begin()); - std::copy(v.begin() + s, v.end(), std::back_inserter(*this)); - } else { // maybe shrink - erase(begin() + vs, end()); - std::copy(v.begin(), v.end(), begin()); - } - return *this; - } - - size_t size() const { return size_internal(); } - - bool empty() const { return (size() == 0); } - - // Return number of elements that can be stored in vector - // without requiring a reallocation of underlying memory - size_t capacity() const { - if (is_inline()) { - return kFit; - } else { - return static_cast(1) << u_.data[kSize - 2]; - } - } - - // Return a pointer to the underlying array. - // Only result[0,size()-1] are defined. - pointer data() { - if (is_inline()) { - return reinterpret_cast(u_.data); - } else { - return outofline_pointer(); - } - } - const_pointer data() const { - return const_cast*>(this)->data(); - } - - // Remove all elements - void clear() { - DiscardStorage(); - u_.data[kSize - 1] = 0; - } - - // Return the ith element - // REQUIRES: 0 <= i < size() - const value_type& at(size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - const value_type& operator[](size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - - // Return a non-const reference to the ith element - // REQUIRES: 0 <= i < size() - value_type& at(size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - value_type& operator[](size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - - value_type& back() { - DCHECK(!empty()); - return at(size() - 1); - } - - const value_type& back() const { - DCHECK(!empty()); - return at(size() - 1); - } - - value_type& front() { - DCHECK(!empty()); - return at(0); - } - - const value_type& front() const { - DCHECK(!empty()); - return at(0); - } - - // Append a T constructed with args to the vector. - // Increases size() by one. - // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - template - void emplace_back(Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - if (s < capacity()) { - new (data() + s) T(std::forward(args)...); - set_size_internal(s + 1); - } else { - EmplaceBackSlow(std::forward(args)...); - } - } - - // Append t to the vector. - // Increases size() by one. - // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - void push_back(const value_type& t) { emplace_back(t); } - void push_back(value_type&& t) { emplace_back(std::move(t)); } - - inline void pop_back() { - DCHECK(!empty()); - const size_t s = size(); - Destroy(data() + s - 1, 1); - set_size_internal(s - 1); - } - - // Resizes the vector to contain "n" elements. - // If "n" is smaller than the initial size, extra elements are destroyed. - // If "n" is larger than the initial size, enough copies of "elem" - // are appended to increase the size to "n". If "elem" is omitted, - // new elements are value-initialized. - void resize(size_t n) { Resize(n, nullptr); } - void resize(size_t n, const value_type& elem) { Resize(n, &elem); } - - iterator begin() { return data(); } - const_iterator begin() const { return data(); } - - iterator end() { return data() + size(); } - const_iterator end() const { return data() + size(); } - - iterator insert(iterator pos, const value_type& v); - - iterator erase(iterator pos) { - DCHECK_LT(pos, end()); - DCHECK_GE(pos, begin()); - std::copy(pos + 1, end(), pos); - pop_back(); - return pos; - } - - iterator erase(iterator first, iterator last); - - // Enlarges the underlying representation so it can hold at least - // "n" elements without reallocation. - // Does not change size() or the actual contents of the vector. - void reserve(size_t n) { - if (n > capacity()) { - // Make room for new elements - Grow(n); - } - } - - // Swap the contents of *this with other. - // REQUIRES: value_type is swappable and copyable. - void swap(InlinedVector& other); - - private: - // Representation can either be inlined or out-of-line. - // In either case, at least sizeof(void*) + 8 bytes are available. - // - // Inlined: - // Last byte holds the length. - // First (length*sizeof(T)) bytes stores the elements. - // Outlined: - // Last byte holds kSentinel. - // Second-last byte holds lg(capacity) - // Preceding 6 bytes hold size. - // First sizeof(T*) bytes hold pointer. - - // Compute rep size. - static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag - static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align - - // See how many fit T we can fit inside kSize, but no more than 254 - // since 255 is used as sentinel tag for out-of-line allocation. - static const unsigned int kSentinel = 255; - static const size_t kFit1 = (kSize - 1) / sizeof(T); - static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1; - - union { - unsigned char data[kSize]; - // Force data to be aligned enough for a pointer. - T* unused_aligner; - } u_; - - inline void InitRep() { u_.data[kSize - 1] = 0; } - inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; } - - inline T* outofline_pointer() const { - T* ptr; - memcpy(&ptr, &u_.data[0], sizeof(ptr)); - return ptr; - } - - inline void set_outofline_pointer(T* p) { - memcpy(&u_.data[0], &p, sizeof(p)); - } - - inline uint64_t outofline_word() const { - uint64_t word; - memcpy(&word, &u_.data[kSize - 8], sizeof(word)); - return word; - } - - inline void set_outofline_word(uint64_t w) { - memcpy(&u_.data[kSize - 8], &w, sizeof(w)); - } - - inline size_t size_internal() const { - uint8_t s = static_cast(u_.data[kSize - 1]); - if (s != kSentinel) { - return static_cast(s); - } else { - const uint64_t word = outofline_word(); - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. - return static_cast(word & 0xffffffffffffull); - } else { - // The sentinel and capacity bits are least-significant bits in word. - return static_cast(word >> 16); - } - } - } - - void set_size_internal(size_t n) { - if (is_inline()) { - DCHECK_LT(n, kSentinel); - u_.data[kSize - 1] = static_cast(n); - } else { - uint64_t word; - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. - word = (static_cast(n) | - (static_cast(u_.data[kSize - 2]) << 48) | - (static_cast(kSentinel) << 56)); - } else { - // The sentinel and capacity bits are least-significant bits in word. - word = ((static_cast(n) << 16) | - (static_cast(u_.data[kSize - 2]) << 8) | - (static_cast(kSentinel))); - } - set_outofline_word(word); - DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n; - } - } - - void DiscardStorage() { - T* base = data(); - size_t n = size(); - Destroy(base, n); - if (!is_inline()) { - port::Free(base); - } - } - - template - void EmplaceBackSlow(Args&&... args) { - const size_t s = size(); - DCHECK_EQ(s, capacity()); - Grow(s + 1, std::forward(args)...); - set_size_internal(s + 1); - } - - // Movers for Grow - // Does nothing. - static void Nop(T* src, size_t n, T* dst) {} - - // Moves srcs[0,n-1] contents to dst[0,n-1]. - static void Move(T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(std::move(*(src + i))); - } - } - - // Initializers for Resize. - // Initializes dst[0,n-1] with empty constructor. - static void ValueInit(const T*, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(); - } - } - - // Initializes dst[0,n-1] with copies of *src. - static void Fill(const T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(*src); - } - } - - void Destroy(T* src, int n) { - if (!std::is_trivially_destructible::value) { - for (int i = 0; i < n; i++) { - (src + i)->~T(); - } - } - } - - // Initialization methods for Grow. - // 1) Leave uninitialized memory. - struct Uninitialized { - void operator()(T*) const {} - }; - // 2) Construct a T with args at not-yet-initialized memory pointed by dst. - struct Construct { - template - void operator()(T* dst, Args&&... args) const { - new (dst) T(std::forward(args)...); - } - }; - - // Grow so that capacity >= n. Uses Mover to move existing elements - // to new buffer, and possibly initialize the new element according - // to InitType. - // We pass the InitType and Mover as template arguments so that - // this code compiles even if T does not support copying or default - // construction. - template - void Grow(size_t n, Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - - // Compute new capacity by repeatedly doubling current capacity - size_t target = 1; - size_t target_lg = 0; - while (target < kFit || target < n) { - // TODO(psrc): Check and avoid overflow? - target_lg++; - target <<= 1; - } - - T* src = data(); - T* dst = static_cast(port::Malloc(target * sizeof(T))); - - // Need to copy elem before discarding src since it might alias src. - InitType{}(dst + s, std::forward(args)...); - Mover(src, s, dst); - DiscardStorage(); - - u_.data[kSize - 1] = kSentinel; - u_.data[kSize - 2] = static_cast(target_lg); - set_size_internal(s); - DCHECK_EQ(capacity(), target); - set_outofline_pointer(dst); - } - - // Resize to size n. Any new elements are initialized by passing - // elem and the destination to Initializer. We pass the Initializer - // as a template argument so that this code compiles even if T does - // not support copying. - template - void Resize(size_t n, const T* elem) { - size_t s = size(); - if (n <= s) { - Destroy(data() + n, s - n); - set_size_internal(n); - return; - } - reserve(n); - DCHECK_GE(capacity(), n); - set_size_internal(n); - Initializer(elem, n - s, data() + s); - } - - template - void AppendRange(Iter first, Iter last, std::input_iterator_tag); - - // Faster path for forward iterators. - template - void AppendRange(Iter first, Iter last, std::forward_iterator_tag); - - template - void AppendRange(Iter first, Iter last); -}; - -// Provide linkage for constants. -template -const size_t InlinedVector::kSizeUnaligned; -template -const size_t InlinedVector::kSize; -template -const unsigned int InlinedVector::kSentinel; -template -const size_t InlinedVector::kFit1; -template -const size_t InlinedVector::kFit; - -template -inline void swap(InlinedVector& a, InlinedVector& b) { - a.swap(b); -} - -template -inline bool operator==(const InlinedVector& a, - const InlinedVector& b) { - return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); -} - -template -inline bool operator!=(const InlinedVector& a, - const InlinedVector& b) { - return !(a == b); -} - -template -inline bool operator<(const InlinedVector& a, - const InlinedVector& b) { - return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); -} - -template -inline bool operator>(const InlinedVector& a, - const InlinedVector& b) { - return b < a; -} - -template -inline bool operator<=(const InlinedVector& a, - const InlinedVector& b) { - return !(b < a); -} - -template -inline bool operator>=(const InlinedVector& a, - const InlinedVector& b) { - return !(a < b); -} - -// ======================================== -// Implementation - -template -inline InlinedVector::InlinedVector() { - InitRep(); -} - -template -inline InlinedVector::InlinedVector(size_t n) { - InitRep(); - if (n > capacity()) { - Grow(n); // Must use Nop in case T is not copyable - } - set_size_internal(n); - ValueInit(nullptr, n, data()); -} - -template -inline InlinedVector::InlinedVector(size_t n, const value_type& elem) { - InitRep(); - if (n > capacity()) { - Grow(n); // Can use Nop since we know we have nothing to copy - } - set_size_internal(n); - Fill(&elem, n, data()); -} - -template -inline InlinedVector::InlinedVector(const InlinedVector& v) { - InitRep(); - *this = v; -} - -template -typename InlinedVector::iterator InlinedVector::insert( - iterator pos, const value_type& v) { - DCHECK_GE(pos, begin()); - DCHECK_LE(pos, end()); - if (pos == end()) { - push_back(v); - return end() - 1; - } - size_t s = size(); - size_t idx = std::distance(begin(), pos); - if (s == capacity()) { - Grow(s + 1); - } - CHECK_LT(s, capacity()); - pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator. - Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1] - std::copy_backward(pos, data() + s - 1, data() + s); - *pos = v; - - set_size_internal(s + 1); - return pos; -} - -template -typename InlinedVector::iterator InlinedVector::erase( - iterator first, iterator last) { - DCHECK_LE(begin(), first); - DCHECK_LE(first, last); - DCHECK_LE(last, end()); - - size_t s = size(); - ptrdiff_t erase_gap = std::distance(first, last); - std::copy(last, data() + s, first); - Destroy(data() + s - erase_gap, erase_gap); - set_size_internal(s - erase_gap); - return first; -} - -template -void InlinedVector::swap(InlinedVector& other) { - using std::swap; // Augment ADL with std::swap. - if (&other == this) { - return; - } - - InlinedVector* a = this; - InlinedVector* b = &other; - - const bool a_inline = a->is_inline(); - const bool b_inline = b->is_inline(); - - if (!a_inline && !b_inline) { - // Just swap the top-level representations. - T* aptr = a->outofline_pointer(); - T* bptr = b->outofline_pointer(); - a->set_outofline_pointer(bptr); - b->set_outofline_pointer(aptr); - - uint64_t aword = a->outofline_word(); - uint64_t bword = b->outofline_word(); - a->set_outofline_word(bword); - b->set_outofline_word(aword); - return; - } - - // Make a the larger of the two to reduce number of cases. - size_t a_size = a->size(); - size_t b_size = b->size(); - if (a->size() < b->size()) { - swap(a, b); - swap(a_size, b_size); - } - DCHECK_GE(a_size, b_size); - - if (b->capacity() < a_size) { - b->Grow(a_size); - } - - // One is inline and one is not. - // 'a' is larger. Swap the elements up to the smaller array size. - std::swap_ranges(a->data(), a->data() + b_size, b->data()); - std::uninitialized_copy(a->data() + b_size, a->data() + a_size, - b->data() + b_size); - Destroy(a->data() + b_size, a_size - b_size); - a->set_size_internal(b_size); - b->set_size_internal(a_size); - DCHECK_EQ(b->size(), a_size); - DCHECK_EQ(a->size(), b_size); -} - -template -template -inline void InlinedVector::AppendRange(Iter first, Iter last, - std::input_iterator_tag) { - std::copy(first, last, std::back_inserter(*this)); -} - -template -template -inline void InlinedVector::AppendRange(Iter first, Iter last, - std::forward_iterator_tag) { - typedef typename std::iterator_traits::difference_type Length; - Length length = std::distance(first, last); - size_t s = size(); - reserve(s + length); - std::uninitialized_copy_n(first, length, data() + s); - set_size_internal(s + length); -} - -template -template -inline void InlinedVector::AppendRange(Iter first, Iter last) { - typedef typename std::iterator_traits::iterator_category IterTag; - AppendRange(first, last, IterTag()); -} +using absl::InlinedVector; } // namespace gtl } // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc deleted file mode 100644 index 2721885c4a..0000000000 --- a/tensorflow/core/lib/gtl/inlined_vector_test.cc +++ /dev/null @@ -1,898 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/lib/gtl/inlined_vector.h" - -#include -#include -#include -#include - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -typedef tensorflow::gtl::InlinedVector IntVec; - -// A type that counts number of live occurrences of the type -static int64 instances = 0; -class Instance { - public: - int value_; - explicit Instance(int x) : value_(x) { instances++; } - Instance(const Instance& x) : value_(x.value_) { instances++; } - ~Instance() { instances--; } - - friend inline void swap(Instance& a, Instance& b) { - using std::swap; - swap(a.value_, b.value_); - } - - friend std::ostream& operator<<(std::ostream& o, const Instance& v) { - return o << "[value:" << v.value_ << "]"; - } -}; - -typedef tensorflow::gtl::InlinedVector InstanceVec; - -// A simple reference counted class to make sure that the proper elements are -// destroyed in the erase(begin, end) test. -class RefCounted { - public: - RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); } - - RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) { - VLOG(5) << "[RefCounted: copy" - << " from count @" << v.count_ << "]"; - Ref(); - } - - ~RefCounted() { - Unref(); - count_ = nullptr; - } - - friend void swap(RefCounted& a, RefCounted& b) { - using std::swap; - swap(a.value_, b.value_); - swap(a.count_, b.count_); - } - - RefCounted& operator=(RefCounted v) { - using std::swap; - swap(*this, v); - return *this; - } - - void Ref() const { - CHECK(count_ != nullptr); - ++(*count_); - VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - void Unref() const { - --(*count_); - CHECK_GE(*count_, 0); - VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - int count() const { return *count_; } - - friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) { - return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]"; - } - - int value_; - int* count_; -}; - -typedef tensorflow::gtl::InlinedVector RefCountedVec; - -// A class with a vtable pointer -class Dynamic { - public: - virtual ~Dynamic() {} - - friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) { - return o << "[Dynamic]"; - } -}; - -typedef tensorflow::gtl::InlinedVector DynamicVec; - -// Append 0..len-1 to *v -static void Fill(IntVec* v, int len, int offset = 0) { - for (int i = 0; i < len; i++) { - v->push_back(i + offset); - } -} - -static IntVec Fill(int len, int offset = 0) { - IntVec v; - Fill(&v, len, offset); - return v; -} - -TEST(IntVec, SimpleOps) { - for (int len = 0; len < 20; len++) { - IntVec v; - const IntVec& cv = v; // const alias - - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - for (int i = 0; i < len; i++) { - EXPECT_EQ(i, v[i]); - } - EXPECT_EQ(v.begin(), v.data()); - EXPECT_EQ(cv.begin(), cv.data()); - - int counter = 0; - for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - EXPECT_EQ(counter, len); - - counter = 0; - for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - EXPECT_EQ(counter, len); - - if (len > 0) { - EXPECT_EQ(0, v.front()); - EXPECT_EQ(len - 1, v.back()); - v.pop_back(); - EXPECT_EQ(len - 1, v.size()); - for (size_t i = 0; i < v.size(); ++i) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, Erase) { - for (int len = 1; len < 20; len++) { - for (int i = 0; i < len; ++i) { - IntVec v; - Fill(&v, len); - v.erase(v.begin() + i); - EXPECT_EQ(len - 1, v.size()); - for (int j = 0; j < i; ++j) { - EXPECT_EQ(j, v[j]); - } - for (int j = i; j < len - 1; ++j) { - EXPECT_EQ(j + 1, v[j]); - } - } - } -} - -// At the end of this test loop, the elements between [erase_begin, erase_end) -// should have reference counts == 0, and all others elements should have -// reference counts == 1. -TEST(RefCountedVec, EraseBeginEnd) { - for (int len = 1; len < 20; ++len) { - for (int erase_begin = 0; erase_begin < len; ++erase_begin) { - for (int erase_end = erase_begin; erase_end <= len; ++erase_end) { - std::vector counts(len, 0); - RefCountedVec v; - for (int i = 0; i < len; ++i) { - v.push_back(RefCounted(i, &counts[i])); - } - - int erase_len = erase_end - erase_begin; - - v.erase(v.begin() + erase_begin, v.begin() + erase_end); - - EXPECT_EQ(len - erase_len, v.size()); - - // Check the elements before the first element erased. - for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(i, v[i].value_); - } - - // Check the elements after the first element erased. - for (size_t i = erase_begin; i < v.size(); ++i) { - EXPECT_EQ(i + erase_len, v[i].value_); - } - - // Check that the elements at the beginning are preserved. - for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(1, counts[i]); - } - - // Check that the erased elements are destroyed - for (int i = erase_begin; i < erase_end; ++i) { - EXPECT_EQ(0, counts[i]); - } - - // Check that the elements at the end are preserved. - for (int i = erase_end; i < len; ++i) { - EXPECT_EQ(1, counts[i]); - } - } - } - } -} - -struct NoDefaultCtor { - explicit NoDefaultCtor(int) {} -}; -struct NoCopy { - NoCopy() {} - NoCopy(const NoCopy&) = delete; -}; -struct NoAssign { - NoAssign() {} - NoAssign& operator=(const NoAssign&) = delete; -}; -struct MoveOnly { - MoveOnly() {} - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; -}; -TEST(InlinedVectorTest, NoDefaultCtor) { - tensorflow::gtl::InlinedVector v(10, NoDefaultCtor(2)); - (void)v; -} -TEST(InlinedVectorTest, NoCopy) { - tensorflow::gtl::InlinedVector v(10); - (void)v; -} -TEST(InlinedVectorTest, NoAssign) { - tensorflow::gtl::InlinedVector v(10); - (void)v; -} -TEST(InlinedVectorTest, MoveOnly) { - gtl::InlinedVector v; - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); -} - -TEST(IntVec, Insert) { - for (int len = 0; len < 20; len++) { - for (int pos = 0; pos <= len; pos++) { - IntVec v; - Fill(&v, len); - v.insert(v.begin() + pos, 9999); - EXPECT_EQ(v.size(), len + 1); - for (int i = 0; i < pos; i++) { - EXPECT_EQ(v[i], i); - } - EXPECT_EQ(v[pos], 9999); - for (size_t i = pos + 1; i < v.size(); i++) { - EXPECT_EQ(v[i], i - 1); - } - } - } -} - -TEST(RefCountedVec, InsertConstructorDestructor) { - // Make sure the proper construction/destruction happen during insert - // operations. - for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - for (int pos = 0; pos <= len; pos++) { - SCOPED_TRACE(pos); - std::vector counts(len, 0); - int inserted_count = 0; - RefCountedVec v; - for (int i = 0; i < len; ++i) { - SCOPED_TRACE(i); - v.push_back(RefCounted(i, &counts[i])); - } - - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - - RefCounted insert_element(9999, &inserted_count); - EXPECT_EQ(1, inserted_count); - v.insert(v.begin() + pos, insert_element); - EXPECT_EQ(2, inserted_count); - // Check that the elements at the end are preserved. - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - EXPECT_EQ(2, inserted_count); - } - } -} - -TEST(IntVec, Resize) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - // Try resizing up and down by k elements - static const int kResizeElem = 1000000; - for (int k = 0; k < 10; k++) { - // Enlarging resize - v.resize(len + k, kResizeElem); - EXPECT_EQ(len + k, v.size()); - EXPECT_LE(len + k, v.capacity()); - for (int i = 0; i < len + k; i++) { - if (i < len) { - EXPECT_EQ(i, v[i]); - } else { - EXPECT_EQ(kResizeElem, v[i]); - } - } - - // Shrinking resize - v.resize(len, kResizeElem); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, InitWithLength) { - for (int len = 0; len < 20; len++) { - IntVec v(len, 7); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(7, v[i]); - } - } -} - -TEST(IntVec, CopyConstructorAndAssignment) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - IntVec v2(v); - EXPECT_EQ(v, v2); - - for (int start_len = 0; start_len < 20; start_len++) { - IntVec v3; - Fill(&v3, start_len, 99); // Add dummy elements that should go away - v3 = v; - EXPECT_EQ(v, v3); - } - } -} - -TEST(OverheadTest, Storage) { - // Check for size overhead. - using tensorflow::gtl::InlinedVector; - EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector)); - EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector)); - - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector)); - EXPECT_EQ(2 * sizeof(char*), - sizeof(InlinedVector)); - EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector)); -} - -TEST(IntVec, Clear) { - for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - IntVec v; - Fill(&v, len); - v.clear(); - EXPECT_EQ(0, v.size()); - EXPECT_EQ(v.begin(), v.end()); - } -} - -TEST(IntVec, Reserve) { - for (size_t len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - for (size_t newlen = 0; newlen < 100; newlen++) { - const int* start_rep = v.data(); - v.reserve(newlen); - const int* final_rep = v.data(); - if (newlen <= len) { - EXPECT_EQ(start_rep, final_rep); - } - EXPECT_LE(newlen, v.capacity()); - - // Filling up to newlen should not change rep - while (v.size() < newlen) { - v.push_back(0); - } - EXPECT_EQ(final_rep, v.data()); - } - } -} - -template -static std::vector Vec(const T& src) { - std::vector result; - for (const auto& elem : src) { - result.push_back(elem); - } - return result; -} - -TEST(IntVec, SelfRefPushBack) { - std::vector std_v; - tensorflow::gtl::InlinedVector v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(std_v, Vec(v)); - - v.push_back(v.back()); - std_v.push_back(std_v.back()); - } - EXPECT_EQ(std_v, Vec(v)); -} - -TEST(IntVec, SelfRefPushBackWithMove) { - std::vector std_v; - gtl::InlinedVector v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(v.back(), std_v.back()); - - v.push_back(std::move(v.back())); - std_v.push_back(std::move(std_v.back())); - } - EXPECT_EQ(v.back(), std_v.back()); -} - -TEST(IntVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - SCOPED_TRACE(l1); - for (int l2 = 0; l2 < 20; l2++) { - SCOPED_TRACE(l2); - IntVec a = Fill(l1, 0); - IntVec b = Fill(l2, 100); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(i, b[i]); - } - for (int i = 0; i < l2; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(100 + i, a[i]); - } - } - } -} - -TEST(InstanceVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - for (int l2 = 0; l2 < 20; l2++) { - InstanceVec a, b; - for (int i = 0; i < l1; i++) a.push_back(Instance(i)); - for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i)); - EXPECT_EQ(l1 + l2, instances); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1 + l2, instances); - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - EXPECT_EQ(i, b[i].value_); - } - for (int i = 0; i < l2; i++) { - EXPECT_EQ(100 + i, a[i].value_); - } - } - } -} - -TEST(IntVec, EqualAndNotEqual) { - IntVec a, b; - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - a.push_back(3); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b.push_back(3); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - b.push_back(7); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.push_back(6); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.clear(); - b.clear(); - for (int i = 0; i < 100; i++) { - a.push_back(i); - b.push_back(i); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - b[i] = b[i] + 1; - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b[i] = b[i] - 1; // Back to before - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - } -} - -TEST(IntVec, RelationalOps) { - IntVec a, b; - EXPECT_FALSE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_FALSE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_TRUE(b <= a); - EXPECT_TRUE(a >= b); - EXPECT_TRUE(b >= a); - b.push_back(3); - EXPECT_TRUE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_TRUE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_FALSE(b <= a); - EXPECT_FALSE(a >= b); - EXPECT_TRUE(b >= a); -} - -TEST(InstanceVec, CountConstructorsDestructors) { - const int start = instances; - for (int len = 0; len < 20; len++) { - InstanceVec v; - for (int i = 0; i < len; i++) { - v.push_back(Instance(i)); - } - EXPECT_EQ(start + len, instances); - - { // Copy constructor should create 'len' more instances. - InstanceVec v_copy(v); - EXPECT_EQ(start + len + len, instances); - } - EXPECT_EQ(start + len, instances); - - // Enlarging resize() must construct some objects - v.resize(len + 10, Instance(100)); - EXPECT_EQ(start + len + 10, instances); - - // Shrinking resize() must destroy some objects - v.resize(len, Instance(100)); - EXPECT_EQ(start + len, instances); - - // reserve() must not increase the number of initialized objects - v.reserve(len + 1000); - EXPECT_EQ(start + len, instances); - - // pop_back() and erase() must destroy one object - if (len > 0) { - v.pop_back(); - EXPECT_EQ(start + len - 1, instances); - if (!v.empty()) { - v.erase(v.begin()); - EXPECT_EQ(start + len - 2, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) { - const int start = instances; - for (int len = 0; len < 20; len++) { - for (int longorshort = 0; longorshort <= 1; ++longorshort) { - InstanceVec longer, shorter; - for (int i = 0; i < len; i++) { - longer.push_back(Instance(i)); - shorter.push_back(Instance(i)); - } - longer.push_back(Instance(len)); - EXPECT_EQ(start + len + len + 1, instances); - - if (longorshort) { - shorter = longer; - EXPECT_EQ(start + (len + 1) + (len + 1), instances); - } else { - longer = shorter; - EXPECT_EQ(start + len + len, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(RangedConstructor, SimpleType) { - std::vector source_v = {4, 5, 6, 7}; - // First try to fit in inline backing - tensorflow::gtl::InlinedVector v(source_v.begin(), source_v.end()); - tensorflow::gtl::InlinedVector empty4; - EXPECT_EQ(4, v.size()); - EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(4, v[0]); - EXPECT_EQ(5, v[1]); - EXPECT_EQ(6, v[2]); - EXPECT_EQ(7, v[3]); - - // Now, force a re-allocate - tensorflow::gtl::InlinedVector realloc_v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector empty2; - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty2.capacity(), realloc_v.capacity()); - EXPECT_EQ(4, realloc_v[0]); - EXPECT_EQ(5, realloc_v[1]); - EXPECT_EQ(6, realloc_v[2]); - EXPECT_EQ(7, realloc_v[3]); -} - -TEST(RangedConstructor, ComplexType) { - // We also use a list here to pass a different flavor of iterator (e.g. not - // random-access). - std::list source_v = {Instance(0)}; - - // First try to fit in inline backing - tensorflow::gtl::InlinedVector v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector empty1; - EXPECT_EQ(1, v.size()); - EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(0, v[0].value_); - - std::list source_v2 = {Instance(0), Instance(1), Instance(2), - Instance(3)}; - // Now, force a re-allocate - tensorflow::gtl::InlinedVector realloc_v(source_v2.begin(), - source_v2.end()); - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty1.capacity(), realloc_v.capacity()); - EXPECT_EQ(0, realloc_v[0].value_); - EXPECT_EQ(1, realloc_v[1].value_); - EXPECT_EQ(2, realloc_v[2].value_); - EXPECT_EQ(3, realloc_v[3].value_); -} - -TEST(RangedConstructor, ElementsAreConstructed) { - std::vector source_v = {"cat", "dog"}; - - // Force expansion and re-allocation of v. Ensures that when the vector is - // expanded that new elements are constructed. - tensorflow::gtl::InlinedVector v(source_v.begin(), source_v.end()); - EXPECT_EQ("cat", v[0]); - EXPECT_EQ("dog", v[1]); -} - -TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) { - auto vec = tensorflow::gtl::InlinedVector{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_EQ(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) { - auto vec = tensorflow::gtl::InlinedVector{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_LE(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, DisparateTypesInList) { - EXPECT_EQ((std::vector{-7, 8}), - Vec(tensorflow::gtl::InlinedVector{-7, 8ULL})); - - EXPECT_EQ( - (std::vector{"foo", "bar"}), - Vec(tensorflow::gtl::InlinedVector{"foo", string("bar")})); -} - -TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) { - tensorflow::gtl::InlinedVector empty; - auto vec = tensorflow::gtl::InlinedVector{Instance(0)}; - EXPECT_EQ(1, vec.size()); - EXPECT_EQ(empty.capacity(), vec.capacity()); - EXPECT_EQ(0, vec[0].value_); -} - -TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) { - auto vec = - tensorflow::gtl::InlinedVector{Instance(0), Instance(1)}; - EXPECT_EQ(2, vec.size()); - EXPECT_LE(2, vec.capacity()); - EXPECT_EQ(0, vec[0].value_); - EXPECT_EQ(1, vec[1].value_); -} - -TEST(DynamicVec, DynamicVecCompiles) { - DynamicVec v; - (void)v; -} - -static void BM_InlinedVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - IntVec v; - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024); - -static void BM_InlinedVectorFillRange(int iters, int len) { - std::unique_ptr ia(new int[len]); - for (int j = 0; j < len; j++) { - ia[j] = j; - } - for (int i = 0; i < iters; i++) { - IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len); - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); - -static void BM_StdVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - std::vector v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_StdVectorFill)->Range(0, 1024); - -bool StringRepresentedInline(string s) { - const char* chars = s.data(); - string s1 = std::move(s); - return s1.data() != chars; -} - -static void BM_InlinedVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - gtl::InlinedVector v; - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); -} -BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024); - -static void BM_StdVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - std::vector v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); - // The purpose of the benchmark is to verify that inlined vector is - // efficient when moving is more efficient than copying. To do so, we - // use strings that are larger than the small string optimization. - CHECK(!StringRepresentedInline(strings[0])); -} -BENCHMARK(BM_StdVectorFillString)->Range(0, 1024); - -namespace { -struct Buffer { // some arbitrary structure for benchmarking. - char* base; - int length; - int capacity; - void* user_data; -}; -} // anonymous namespace - -static void BM_InlinedVectorTenAssignments(int iters, int len) { - typedef tensorflow::gtl::InlinedVector BufferVec; - - BufferVec src; - src.resize(len); - - iters *= 10; - BufferVec dst; - for (int i = 0; i < iters; i++) { - dst = src; - } -} -BENCHMARK(BM_InlinedVectorTenAssignments) - ->Arg(0) - ->Arg(1) - ->Arg(2) - ->Arg(3) - ->Arg(4) - ->Arg(20); - -static void BM_CreateFromInitializerList(int iters) { - for (; iters > 0; iters--) { - tensorflow::gtl::InlinedVector x{1, 2, 3}; - (void)x[0]; - } -} -BENCHMARK(BM_CreateFromInitializerList); - -namespace { - -struct LargeSwappable { - LargeSwappable() : d_(1024, 17) {} - ~LargeSwappable() {} - LargeSwappable(const LargeSwappable& o) : d_(o.d_) {} - - friend void swap(LargeSwappable& a, LargeSwappable& b) { - using std::swap; - swap(a.d_, b.d_); - } - - LargeSwappable& operator=(LargeSwappable o) { - using std::swap; - swap(*this, o); - return *this; - } - - std::vector d_; -}; - -} // namespace - -static void BM_LargeSwappableElements(int iters, int len) { - typedef tensorflow::gtl::InlinedVector Vec; - Vec a(len); - Vec b; - while (--iters >= 0) { - using std::swap; - swap(a, b); - } -} -BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024); - -} // namespace tensorflow diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 07b2e3426b..bb841aeab7 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -625,6 +625,7 @@ def tf_additional_lib_deps(): """Additional dependencies needed to build TF libraries.""" return [ "@com_google_absl//absl/base:base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:optional", ] + if_static( diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index 7f851e3646..f25ed700d6 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -41,6 +41,7 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_ #include +#include #include "tensorflow/stream_executor/host_or_device_scalar.h" #include "tensorflow/stream_executor/lib/array_slice.h" -- GitLab From db9cc8dc4aecec10eb8052666dabbbd7a9952f1f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 09:03:18 -0700 Subject: [PATCH 120/540] Fix categorical feature handler accumulator to use high precision 64 bit accumulator. PiperOrigin-RevId: 211642436 --- .../lib/learner/batch/categorical_split_handler.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index e6407174b1..35d727482b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -141,11 +141,18 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): # The bias is computed on gradients and hessians (and not # filtered_gradients) which have exactly one value per example, so we # don't double count a gradient in multivalent columns. + # Since unsorted_segment_sum can be numerically unstable, use 64bit + # operation. + gradients64 = math_ops.cast(gradients, dtypes.float64) + hessians64 = math_ops.cast(hessians, dtypes.float64) per_partition_gradients = math_ops.unsorted_segment_sum( - gradients, mapped_partitions, array_ops.size(unique_partitions)) + gradients64, mapped_partitions, array_ops.size(unique_partitions)) per_partition_hessians = math_ops.unsorted_segment_sum( - hessians, mapped_partitions, array_ops.size(unique_partitions)) - + hessians64, mapped_partitions, array_ops.size(unique_partitions)) + per_partition_gradients = math_ops.cast(per_partition_gradients, + dtypes.float32) + per_partition_hessians = math_ops.cast(per_partition_hessians, + dtypes.float32) # Prepend a bias feature per partition that accumulates the stats for all # examples in that partition. # Bias is added to the stats even if there are no examples with values in -- GitLab From 47860208eee575119b0dd1b6168dc24cf51caf64 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 5 Sep 2018 09:08:18 -0700 Subject: [PATCH 121/540] [XLA] Give "big" and "small" params different colors in hlo_graph_dumper. PiperOrigin-RevId: 211643209 --- .../compiler/xla/service/hlo_graph_dumper.cc | 41 +++++++++---------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 3041d94fa9..0345a2a5f8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -120,12 +120,19 @@ class NodeFilter { std::function filter_; }; +// We arbitrarily set this as the boundary between "large" and "small" +// instructions. +bool IsSmall(const HloInstruction* instr) { + return ShapeUtil::ElementsInRecursive(instr->shape()) < 4096; +} + // Node color schemes, used by NodeColorAttributes. enum ColorScheme { kBlue, kBrown, kDarkBlue, kDarkGreen, + kDarkOrange, kDarkRed, kGray, kGreen, @@ -158,6 +165,10 @@ NodeColors NodeColorsForScheme(ColorScheme color) { return NodeColors{"filled", "#1565c0", "#003c8f", "white"}; case kDarkGreen: return NodeColors{"filled", "#2e7d32", "#005005", "white"}; + case kDarkOrange: + // This is more of a "medium" orange, made to look close to kOrange; + // there's probably room for a darker weight if desired. + return NodeColors{"filled", "#ffb74d", "#c88719", "black"}; case kDarkRed: return NodeColors{"filled", "#b71c1c", "#7f0000", "white"}; case kGray: @@ -893,7 +904,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { sharding_colors_.emplace(instr->sharding(), color); return color; } - const auto kParameterColor = kOrange; + + // Choose different weights of orange for small vs large parameters. This + // distinction is often important, especially in fusion nodes. + auto parameter_color = IsSmall(instr) ? kOrange : kDarkOrange; // Special case: If this instruction has a parameter merged into it, paint it // the same color as a parameter. Unless the merged-in parameter is a @@ -905,7 +919,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { ShouldMergeIntoUsers(operand) && TryGetFusionParameterConstant(operand) == nullptr; })) { - return kParameterColor; + return parameter_color; } // Pick different colors or shapes for instructions which are particularly @@ -1015,7 +1029,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReducePrecision: return kRed; case HloOpcode::kParameter: - return kParameterColor; + return parameter_color; case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: @@ -1160,20 +1174,6 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return StrJoin(lines, "
"); } -// Gets the total number of array elements in the given shape. For tuples, this -// is the sum of all the sizes of all of the array elements recursively in the -// tuple. -static int64 TotalElementsInShape(const Shape& shape) { - int64 elems = 0; - ShapeUtil::ForEachSubshape( - shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { - if (ShapeUtil::IsArray(subshape)) { - elems += ShapeUtil::ElementsIn(subshape); - } - }); - return elems; -} - void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1196,14 +1196,11 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } // We print "small" arrays using a hollow arrowhead and "large" arrays using - // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" - // means. - bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; - + // a filled arrowhead. constexpr char kEdgeFmt[] = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(StrFormat(kEdgeFmt, InstructionId(from), InstructionId(to), - (is_big_array ? "normal" : "empty"), + (IsSmall(from) ? "empty" : "normal"), from->name(), to->name(), edge_label)); }; -- GitLab From 11548e0ab987ec3935b1dfb87753c4bbe95f6ad1 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Wed, 5 Sep 2018 09:39:04 -0700 Subject: [PATCH 122/540] Set CUDA_VISIBLE_DEVICES='' tfcompile and tfcompile tests' genrules. This prevents these build-time rules from accessing any GPUs which might be present on the build machine and interfering with GPU tests which might be running concurrently. PiperOrigin-RevId: 211647681 --- tensorflow/compiler/aot/tests/BUILD | 7 +++- tensorflow/compiler/aot/tfcompile.bzl | 59 +++++++++++++++++---------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 723e9bec8a..8d94f5495c 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -67,7 +67,12 @@ genrule( "test_graph_tfmatmulandadd.pb", "test_graph_tfsplits.pb", ], - cmd = "$(location :make_test_graphs) --out_dir $(@D)", + # Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any + # GPUs which might be present. This is important because builds may run + # concurrently with tests, and tests need to be able to assume that they + # have control of the full GPU. + cmd = "CUDA_VISIBLE_DEVICES='' " + + "$(location :make_test_graphs) --out_dir $(@D)", tags = ["manual"], tools = [":make_test_graphs"], ) diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 326f73b975..792b7fe14a 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -105,12 +105,18 @@ def tf_library( freeze_file = freeze_name + ".pb" # First run tfcompile to generate the list of out_nodes. + # + # Here and below, we set CUDA_VISIBLE_DEVICES='' to prevent the code we + # launch from using any GPUs which might be present. This is important + # because builds may run concurrently with tests, and tests need to be + # able to assume that they have control of the full GPU. out_nodes_file = "out_nodes_" + freeze_name native.genrule( name = ("gen_" + out_nodes_file), srcs = [config], outs = [out_nodes_file], - cmd = ("$(location " + tfcompile_tool + ")" + + cmd = ("CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + " --config=$(location " + config + ")" + " --dump_fetch_nodes > $@"), tools = [tfcompile_tool], @@ -142,9 +148,12 @@ def tf_library( out_nodes_file, ] + freeze_saver_srcs, outs = [freeze_file], - cmd = ("$(location " + - "//tensorflow/python/tools:freeze_graph)" + - freeze_args), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + + "//tensorflow/python/tools:freeze_graph)" + + freeze_args + ), tools = ["//tensorflow/python/tools:freeze_graph"], tags = tags, ) @@ -177,16 +186,19 @@ def tf_library( metadata_object_file, function_object_file, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_header=$(@D)/" + header_file + - " --out_metadata_object=$(@D)/" + metadata_object_file + - " --out_function_object=$(@D)/" + function_object_file + - " " + flags + " " + profiling_flag), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_header=$(@D)/" + header_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + + " " + flags + " " + profiling_flag + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, @@ -216,14 +228,17 @@ def tf_library( outs = [ session_module_pb, ], - cmd = ("$(location " + tfcompile_tool + ")" + - " --graph=$(location " + tfcompile_graph + ")" + - " --config=$(location " + config + ")" + - " --entry_point=" + ep + - " --cpp_class=" + cpp_class + - " --target_triple=" + target_llvm_triple() + - " --out_session_module=$(@D)/" + session_module_pb + - " " + flags), + cmd = ( + "CUDA_VISIBLE_DEVICES='' " + + "$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + flags + ), tools = [tfcompile_tool], visibility = visibility, testonly = testonly, -- GitLab From 08313b87960962efb98bcd684776c8305fa9909a Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 5 Sep 2018 10:02:12 -0700 Subject: [PATCH 123/540] Optimize CuboidConvolutionBwdInput. ~25-30% speedup when compiled with AVX. * collapse inner dims before contraction * eval kernel tensor before contraction PiperOrigin-RevId: 211651030 --- .../eigen_backward_cuboid_convolutions.h | 201 +++++++++--------- .../eigen_backward_spatial_convolutions.h | 7 +- 2 files changed, 107 insertions(+), 101 deletions(-) diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h index 3ebeb7be2b..27918b410b 100644 --- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -51,14 +51,18 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< internal::traits::NumDimensions>, const TensorContractionOp< const array< - IndexPair::Index>, 2>, - const TensorReshapingOp< + IndexPair::Index>, 1>, + const Eigen::TensorForcedEvalOp::Index, - 3>, - const TensorReverseOp, const Kernel> >, + 2>, + const TensorShufflingOp< + const array< + typename internal::traits::Index, 5>, + const TensorReverseOp, + const Kernel> > > >, const TensorReshapingOp< const DSizes::Index, - 3>, + 2>, const TensorVolumePatchOp > > >, TensorReshapingOp< @@ -66,24 +70,27 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< internal::traits::NumDimensions>, const TensorContractionOp< const array< - IndexPair::Index>, 2>, + IndexPair::Index>, 1>, const TensorReshapingOp< const DSizes::Index, - 3>, + 2>, const TensorVolumePatchOp >, - const TensorReshapingOp< + const Eigen::TensorForcedEvalOp::Index, - 3>, - const TensorReverseOp, - const Kernel> > > > >::type + 2>, + const TensorShufflingOp< + const array< + typename internal::traits::Index, 5>, + const TensorReverseOp, + const Kernel> > > > > > >::type CuboidConvolutionBackwardInput( const Kernel& kernel, const OutputBackward& output_backward, typename internal::traits::Index inputPlanes, typename internal::traits::Index inputRows, typename internal::traits::Index inputCols, - const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1, - const DenseIndex strideCols = 1) { + const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1) { typedef typename internal::traits::Index TensorIndex; const TensorRef::Scalar, internal::traits::NumDimensions, @@ -125,58 +132,45 @@ CuboidConvolutionBackwardInput( const TensorIndex outputCols = isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4]; - TensorIndex forward_pad_z, forward_pad_y, forward_pad_x; - const TensorIndex size_z = - Eigen::divup(inputPlanes, static_cast(stridePlanes)); - const TensorIndex size_y = - Eigen::divup(inputRows, static_cast(strideRows)); - const TensorIndex size_x = - Eigen::divup(inputCols, static_cast(strideCols)); - - // Infer padding type. - if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) { - // SAME padding. - const TensorIndex dz = numext::maxi( - 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes); - const TensorIndex dy = numext::maxi( - 0, (size_y - 1) * strideRows + kernelRows - inputRows); - const TensorIndex dx = numext::maxi( - 0, (size_x - 1) * strideCols + kernelCols - inputCols); - - forward_pad_z = dz / 2; - forward_pad_y = dy / 2; - forward_pad_x = dx / 2; - } else { - // VALID padding. - forward_pad_z = 0; - forward_pad_y = 0; - forward_pad_x = 0; - } - const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z; - const TensorIndex padding_top = kernelRows - 1 - forward_pad_y; - const TensorIndex padding_left = kernelCols - 1 - forward_pad_x; - - const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - - (outputPlanes - 1) * stridePlanes - 1 - - padding_ztop; - const TensorIndex padding_bottom = inputRows + kernelRows - 1 - - (outputRows - 1) * strideRows - 1 - - padding_top; - const TensorIndex padding_right = inputCols + kernelCols - 1 - - (outputCols - 1) * strideCols - 1 - - padding_left; - - eigen_assert(padding_ztop >= 0); - eigen_assert(padding_zbottom >= 0); + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + // Computing the forward padding. + const TensorIndex forward_pad_top_z = numext::maxi( + 0, + ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2); + const TensorIndex forward_pad_top = numext::maxi( + 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2); + const TensorIndex forward_pad_left = numext::maxi( + 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2); + + const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z; + const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top; + const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left; + + const TensorIndex padding_bottom_z = inputPlanes - + (outputPlanes - 1) * plane_stride - 2 - + padding_top_z + kernelPlanesEff; + const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride - + 2 - padding_top + kernelRowsEff; + const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride - + 2 - padding_left + kernelColsEff; + + eigen_assert(padding_top_z >= 0); eigen_assert(padding_top >= 0); eigen_assert(padding_left >= 0); + eigen_assert(padding_bottom_z >= 0); eigen_assert(padding_bottom >= 0); eigen_assert(padding_right >= 0); - // The kernel has dimensions filters X channels X patch_planes X patch_rows X - // patch_cols. + // The kernel has dimensions : + // filters x channels x patch_planes x patch_rows x patch_cols. // We need to reverse the kernel along the spatial dimensions. - array kernel_reverse; + Eigen::array kernel_reverse; if (isColMajor) { kernel_reverse[0] = false; kernel_reverse[1] = false; @@ -191,15 +185,35 @@ CuboidConvolutionBackwardInput( kernel_reverse[4] = false; } - DSizes kernel_dims; + // Reorder the dimensions to: + // filters x patch_planes x patch_rows x patch_cols x channels + array kernel_shuffle; if (isColMajor) { - kernel_dims[0] = kernelFilters; - kernel_dims[1] = kernelChannels; - kernel_dims[2] = kernelRows * kernelCols * kernelPlanes; + // From: filters x channels x planes x rows x cols + // To: filters x planes x rows x cols x channels + kernel_shuffle[0] = 0; + kernel_shuffle[1] = 2; + kernel_shuffle[2] = 3; + kernel_shuffle[3] = 4; + kernel_shuffle[4] = 1; } else { - kernel_dims[0] = kernelRows * kernelCols * kernelPlanes; + // From: cols x rows x planes x channels x filters + // To: channels x cols x rows x planes x filters + kernel_shuffle[0] = 3; + kernel_shuffle[1] = 0; + kernel_shuffle[2] = 1; + kernel_shuffle[3] = 2; + kernel_shuffle[4] = 4; + } + + // Collapse the dims + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols; kernel_dims[1] = kernelChannels; - kernel_dims[2] = kernelFilters; + } else { + kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols; + kernel_dims[0] = kernelChannels; } // The output_backward has dimensions out_depth X out_planes X out_rows X @@ -208,36 +222,32 @@ CuboidConvolutionBackwardInput( // dimensions: // out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes * // input_rows * input_cols * OTHERS) - DSizes pre_contract_dims; + DSizes pre_contract_dims; if (isColMajor) { - pre_contract_dims[0] = kernelFilters; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[2] = inputRows * inputCols * inputPlanes; + pre_contract_dims[0] = + kernelFilters * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = inputPlanes * inputRows * inputCols; for (int i = 4; i < NumDims; ++i) { - pre_contract_dims[2] *= out.dimension(i); + pre_contract_dims[1] *= out.dimension(i); } } else { - pre_contract_dims[2] = kernelFilters; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[0] = inputRows * inputCols * inputPlanes; + pre_contract_dims[1] = + kernelFilters * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[0] = inputPlanes * inputRows * inputCols; for (int i = 0; i < NumDims - 4; ++i) { pre_contract_dims[0] *= out.dimension(i); } } - // We will contract along dimensions (0, 2) in kernel and (0, 1) in - // output_backward, if this is col-major, and - // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this - // row-major. - array, 2> contract_dims; + // We will contract along the fused dimension that contains the kernelFilters, + // kernelPlanes, kernelRows and kernelCols. + array, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) contract_dims[0] = IndexPair(0, 0); - contract_dims[1] = IndexPair(2, 1); } else { // row-major: output.patches.contract(kernel) - contract_dims[0] = IndexPair(1, 0); - contract_dims[1] = IndexPair(2, 2); + contract_dims[0] = IndexPair(1, 1); } // Post contraction, the dimensions of the input_backprop is @@ -261,40 +271,31 @@ CuboidConvolutionBackwardInput( } } - DSizes strides; - for (int i = 0; i < NumDims; i++) { - strides[i] = 1; - } - if (isColMajor) { - strides[1] = stridePlanes; - strides[2] = strideRows; - strides[3] = strideCols; - } else { - strides[NumDims - 2] = stridePlanes; - strides[NumDims - 3] = strideRows; - strides[NumDims - 4] = strideCols; - } - return choose( Cond::Layout == ColMajor>(), kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) .reshape(kernel_dims) + .eval() .contract(output_backward .extract_volume_patches( kernelPlanes, kernelRows, kernelCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, padding_ztop, - padding_zbottom, padding_top, padding_bottom, + plane_stride, row_stride, col_stride, padding_top_z, + padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims), contract_dims) .reshape(post_contract_dims), output_backward .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - padding_ztop, padding_zbottom, padding_top, + plane_stride, row_stride, col_stride, + padding_top_z, padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims) - .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims), + .contract(kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) + .reshape(kernel_dims) + .eval(), contract_dims) .reshape(post_contract_dims)); } diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index cb0a76dac4..8d06107553 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -189,14 +189,19 @@ SpatialConvolutionBackwardInput( } #endif - // Reorder the dimensions to filters X patch_rows X patch_cols X channels + // Reorder the dimensions to: + // filters x patch_rows x patch_cols x channels array kernel_shuffle; if (isColMajor) { + // From: filters x channels x rows x cols + // To: filters x rows x cols x channels kernel_shuffle[0] = 0; kernel_shuffle[1] = 2; kernel_shuffle[2] = 3; kernel_shuffle[3] = 1; } else { + // From: cols x rows x channels x filters + // To: channels x cols x rows x filters kernel_shuffle[0] = 2; kernel_shuffle[1] = 0; kernel_shuffle[2] = 1; -- GitLab From 7fa693209fe238478739b3982f652a7e35be91f3 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Wed, 5 Sep 2018 10:34:12 -0700 Subject: [PATCH 124/540] Add HloSchedule class representing a sequential order of an HloModule. Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector. This CL replaces this with a proper class which results in better encapsulation of code which deals with schedules and better enforcement of invariants. This CL also fixes a corner-case bug in dataflow analysis, where values of instructions which are live out of the computation erroneously did not interfere with the values of instructions scheduled after the root instruction. PiperOrigin-RevId: 211656888 --- tensorflow/compiler/xla/service/BUILD | 48 +++ .../compiler/xla/service/buffer_assignment.cc | 28 +- .../xla/service/buffer_assignment_test.cc | 98 ++--- .../xla/service/buffer_liveness_test.cc | 42 +-- .../compiler/xla/service/cpu/cpu_compiler.cc | 56 ++- .../compiler/xla/service/cpu/ir_emitter.cc | 2 +- .../compiler/xla/service/cpu/ir_emitter.h | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/gpu_hlo_schedule.cc | 6 +- .../xla/service/gpu/gpu_hlo_schedule.h | 4 +- .../compiler/xla/service/heap_simulator.cc | 43 +-- .../compiler/xla/service/heap_simulator.h | 48 ++- .../xla/service/heap_simulator_test.cc | 36 +- .../xla/service/hlo_alias_analysis_test.cc | 16 +- .../xla/service/hlo_dataflow_analysis_test.cc | 29 +- .../compiler/xla/service/hlo_ordering.cc | 86 ++--- .../compiler/xla/service/hlo_ordering.h | 22 +- .../compiler/xla/service/hlo_ordering_test.cc | 101 ++++++ .../xla/service/hlo_rematerialization.cc | 87 ++--- .../xla/service/hlo_rematerialization.h | 19 +- .../xla/service/hlo_rematerialization_test.cc | 46 +-- .../compiler/xla/service/hlo_schedule.cc | 291 +++++++++++++++ .../compiler/xla/service/hlo_schedule.h | 151 ++++++++ .../compiler/xla/service/hlo_schedule_test.cc | 341 +++++++++++++++++ .../compiler/xla/service/hlo_scheduling.cc | 230 ++---------- .../compiler/xla/service/hlo_scheduling.h | 54 +-- .../xla/service/hlo_scheduling_test.cc | 343 +++--------------- 27 files changed, 1325 insertions(+), 905 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_schedule.cc create mode 100644 tensorflow/compiler/xla/service/hlo_schedule.h create mode 100644 tensorflow/compiler/xla/service/hlo_schedule_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f6cfac6537..612302781c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -989,6 +989,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1036,6 +1037,7 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1049,6 +1051,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1062,6 +1065,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1082,6 +1086,7 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1089,6 +1094,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1102,6 +1108,7 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", + ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1125,6 +1132,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1169,6 +1177,43 @@ cc_library( ], ) +cc_library( + name = "hlo_schedule", + srcs = ["hlo_schedule.cc"], + hdrs = ["hlo_schedule.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_schedule_test", + srcs = ["hlo_schedule_test.cc"], + deps = [ + ":heap_simulator", + ":hlo", + ":hlo_dce", + ":hlo_ordering", + ":hlo_parser", + ":hlo_schedule", + ":hlo_scheduling", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "hlo_scheduling", srcs = ["hlo_scheduling.cc"], @@ -1177,6 +1222,7 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1205,6 +1251,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2366,6 +2413,7 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8b8c6bfd26..0f0af57626 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -617,18 +617,24 @@ Status BufferAssignment::ComputeSummaryStats() { } // Only compute total fragmentation if all computations have schedules. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_); + bool schedule_complete = true; for (const auto& computation : module_->computations()) { - const std::vector* sequence = - liveness_->hlo_ordering().SequentialOrder(*computation); - if (sequence != nullptr) { - module_sequence.emplace(computation, *sequence); + if (!computation->IsFusionComputation()) { + const std::vector* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence == nullptr) { + schedule_complete = false; + } else { + schedule.set_sequence(computation, *sequence); + } } } - if (module_sequence.size() == module_->computation_count()) { + if (schedule_complete) { + TF_RETURN_IF_ERROR(schedule.Verify()); TF_ASSIGN_OR_RETURN( const int64 min_size, - HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -1064,7 +1070,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(&assignment->module()); FlatSet all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; @@ -1072,7 +1078,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); - module_sequence[computation] = *instruction_sequence; + schedule.set_sequence(computation, *instruction_sequence); all_buffers_to_assign.insert(buffers_to_assign.begin(), buffers_to_assign.end()); } @@ -1090,7 +1096,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique( absl::make_unique(alignment)), - assignment->module(), module_sequence, + assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1121,7 +1127,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run( absl::make_unique( absl::make_unique(alignment)), - *computation, *instruction_sequence, + *computation, HloInstructionSequence(*instruction_sequence), assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 7398f105a0..03e155fc11 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -120,14 +122,10 @@ class BufferAssignmentTest : public HloVerifiedTestBase { HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[module->entry_computation()] = - std::vector(instruction_sequence.begin(), - instruction_sequence.end()); + HloSchedule schedule(module); + schedule.set_sequence(module->entry_computation(), instruction_sequence); return BufferAssigner::Run( - module, - absl::make_unique(module, - module_sequence), + module, absl::make_unique(schedule), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1785,11 +1783,10 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, - absl::make_unique(module, sequence), + module, absl::make_unique(schedule), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -2096,17 +2093,25 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // Create a sequential order among all the instructions in the entry // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = { - token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + schedule.set_sequence( + module->entry_computation(), + {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}); + TF_ASSERT_OK(schedule.Verify()); + TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run( - module, absl::make_unique(module, sequence), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2263,29 +2268,6 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } -static bool IsPostOrderTraversal( - const std::vector& sequence) { - tensorflow::gtl::FlatSet seen_so_far; - auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { - return seen_so_far.count(instruction) == 0; - }; - - for (auto instruction : sequence) { - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), has_not_been_seen_yet) || - std::any_of(instruction->control_predecessors().begin(), - instruction->control_predecessors().end(), - has_not_been_seen_yet)) { - return false; // Not a post order. - } - if (!seen_so_far.insert(instruction).second) { - return false; // Not a "traversal". - } - } - - return true; -} - TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); @@ -2340,27 +2322,27 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module); - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); - // To trigger b/38494731, we want a specific Hlo sequence for the + // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. - sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + schedule.set_sequence(module->entry_computation(), + {input1, weights1, one, output1, while1->operand(0), + while1, input0, weights0, zero, output0, + while0->operand(0), while0, gte0, gte1, root_add}); - // If this ASSERT_TRUE fails, we constructed a bogus sequence above - // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + // If this ASSERT fails, we constructed a bogus sequence above and this test + // itself is buggy. + TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run( - module, absl::make_unique(module, sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run(module, + absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 26e26e316d..414bfe7999 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -166,12 +167,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto module = CreateNewModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -291,13 +292,12 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, negate, exp, add}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, negate, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build(add)); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, add, recv, - recv_done, send, send_done}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {param, add, token, recv, recv_done, send, send_done}); + TF_ASSERT_OK(schedule.Verify()); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 796f36510e..e7b6075994 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -584,16 +584,14 @@ StatusOr> CpuCompiler::RunBackend( // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run(module.get(), - absl::make_unique( - module.get(), module_sequence), + absl::make_unique(schedule), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -627,9 +625,10 @@ StatusOr> CpuCompiler::RunBackend( } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() @@ -637,9 +636,10 @@ StatusOr> CpuCompiler::RunBackend( : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &module_sequence.at(entry_computation))); + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -771,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module, - absl::make_unique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -824,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, entry_point_name, - /*is_top_level_computation=*/true, - &module_sequence.at(computation))); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + &schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index e5cf15c686..df8c2a636b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -110,7 +110,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 58a333b8fb..3df99464ba 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -98,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a68b7a1bef..13ccff35f8 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -813,6 +813,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 743035a84e..ea9376e101 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -198,11 +199,12 @@ StatusOr> GpuHloSchedule::Build( // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( - schedule->thunk_launch_order_, - ScheduleOneComputation( + HloInstructionSequence sequence, + ScheduleComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); + schedule->thunk_launch_order_ = sequence.instructions(); } else { // BFS tends to increase concurrency, but also increases memory usage. BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 30a0e7cecd..07a7fc67aa 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -33,7 +33,9 @@ namespace gpu { // launches, because thunks may be scheduled onto concurrent streams. This // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk -// launch order. +// launch order. This class differs from xla::HloSchedule in that HloSchedule +// represents a total order of all instructions in the module for backends which +// execute HLO instructions strictly sequentially. class GpuHloSchedule { public: // Constructs an GpuHloSchedule for the given module, based on the given diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 38c3982ebf..e0f3a7e0e2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { + if (schedule.empty()) { return 0; } - const HloModule* module = module_sequence.begin()->first->parent(); + const HloModule* module = schedule.module(); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -47,14 +47,13 @@ StatusOr HeapSimulator::MinimumMemoryForModule( TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique(), *module, - module_sequence, *points_to_analysis, size_function)); + schedule, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* @@ -71,13 +70,13 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); + HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); const HloComputation* entry_computation = module.entry_computation(); - const std::vector& instruction_sequence = - FindOrDie(module_sequence, entry_computation); + const HloInstructionSequence& instruction_sequence = + schedule.sequence(entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -86,13 +85,13 @@ StatusOr HeapSimulator::Run( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, const tensorflow::gtl::FlatMap* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr, memory_by_computation); + /*schedule=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -102,7 +101,7 @@ StatusOr HeapSimulator::Run( // 'instruction_sequence'. Status HeapSimulator::RunComputation( const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation( std::vector dead_buffers_to_free; std::vector operand_buffers_to_free; - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation( // The order that the sub-computations are simulated does not affect // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. - if (module_sequence_ != nullptr) { + if (schedule_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const std::vector& called_sequence = - FindOrDie(*module_sequence_, called_computation); + const HloInstructionSequence& called_sequence = + schedule_->sequence(called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence, + const HloSchedule* schedule, const tensorflow::gtl::FlatMap* memory_by_computation) : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence), + schedule_(schedule), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); + debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } HeapSimulator::~HeapSimulator() {} diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index af05bedee7..ffbf947d5a 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -88,23 +89,22 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given - // module_sequence), assuming no fragmentation. + // schedule), assuming no fragmentation. static StatusOr MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. static StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given - // module_sequence, which must contain a topologically-consistent total + // schedule, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid // if instructions are not run in exactly this sequence. // @@ -112,12 +112,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr Run( - std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr Run(std::unique_ptr algorithm, + const HloModule& module, + const HloSchedule& schedule, + const TuplePointsToAnalysis& points_to_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -126,7 +126,7 @@ class HeapSimulator { static StatusOr Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), @@ -134,21 +134,19 @@ class HeapSimulator { memory_by_computation = nullptr); private: - // If 'module_sequence' is non-null, it is used to find kCall and kWhile + // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. - HeapSimulator( - std::unique_ptr algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, - const tensorflow::gtl::FlatMap* - memory_by_computation = nullptr); + HeapSimulator(std::unique_ptr algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, const HloSchedule* schedule = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); - Status RunComputation( - const HloComputation& computation, - const std::vector& instruction_sequence, - const TuplePointsToAnalysis& points_to_analysis); + Status RunComputation(const HloComputation& computation, + const HloInstructionSequence& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis); bool IgnoreBuffer(const BufferValue* buffer) const; void Alloc(const BufferValue* buffer, const HloInstruction* instruction); @@ -169,11 +167,11 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // schedule_ is set by buffer assignment, and memory_by_computation_ is // set by hlo scheduling. Then, in RunComputation, we check both in order to // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. - const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const HloSchedule* schedule_; const tensorflow::gtl::FlatMap* memory_by_computation_; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 576c5ff7a4..1d98c45567 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) - .ValueOrDie()); + HloSchedule schedule(module.get()); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_data, cond_lt}); + schedule.set_sequence(body_computation, {body_param}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ( + 56, + HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -149,10 +153,11 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run( - std::move(algorithm), *module_->entry_computation(), - instruction_sequence, *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + result_ = + HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), + HloInstructionSequence(instruction_sequence), + *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { @@ -168,11 +173,12 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_.get()); tensorflow::gtl::FlatMap reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - module_sequence[instruction->parent()].push_back(instruction); + schedule.GetOrCreateSequence(instruction->parent()) + .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -185,8 +191,8 @@ class HeapSimulatorTracker { }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, - module_sequence, *points_to_analysis_, size_fn) + result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, + *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 54abe3345d..0cd0ab36fc 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -885,18 +885,20 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // For a sequential order, if there is interference iff the negate is after // the while. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[body] = {body_param, body_root}; - sequence[condition] = {cond_param, cond_root}; + HloSchedule schedule(module_); + schedule.set_sequence(body, {body_param, body_root}); + schedule.set_sequence(condition, {cond_param, cond_root}); { - sequence[entry] = {init, xla_while, negate, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, xla_while, negate, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { - sequence[entry] = {init, negate, xla_while, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, negate, xla_while, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 62eea2b06c..0a86f83ed9 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { bool ssa_form = GetParam(); RunAnalysis(ssa_form); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param, xla_while}}); - sequence.insert({condition, {cond_param, cond_constant}}); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, xla_while}); + schedule.set_sequence(condition, {cond_param, cond_constant}); // Construct the order such that 'constant' and its use 'exp' are before // body_param. - sequence.insert({body, {constant, exp, body_param, add}}); + schedule.set_sequence( + body, {constant, exp, body_param, add, dead_constant, dead_negate}); + TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(schedule); // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. @@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - std::vector order = {param, negate, exp, add}; - sequence.emplace(entry, order); - - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, negate, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 0581d5c404..2105f7a349 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } + + if (a.live_out_of_module()) { + VLOG(4) << a << " is live out of module and defined before " << b; + return false; + } + // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), @@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } } + + if (a.instruction()->parent() == b.instruction()->parent()) { + for (const HloPosition& position : a.positions()) { + if (position.instruction == + a.instruction()->parent()->root_instruction()) { + VLOG(4) << a << " is live out of computation and defined before " << b + << " which is in same computation"; + return false; + } + } + } + return true; } @@ -336,15 +355,24 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : HloOrdering(module), module_sequence_(module_sequence) { +SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) + : HloOrdering(schedule.module()), schedule_(schedule) { + Initialize(); +} + +SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) + : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { + Initialize(); +} + +void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. - for (auto computation_order : module_sequence_) { - const std::vector& order = computation_order.second; + TF_DCHECK_OK(schedule_.Verify()); + for (const auto& computation_sequence : schedule_.sequences()) { + const std::vector& order = + computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); + InsertOrDie(&order_position_, order[i], i); } } } @@ -362,49 +390,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - auto find_it = module_sequence_.find(&computation); - return find_it == module_sequence_.end() ? nullptr : &find_it->second; + return schedule_.is_computation_scheduled(&computation) + ? &schedule_.sequence(&computation).instructions() + : nullptr; } string SequentialHloOrdering::ToString() const { - std::vector pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto* computation : module_->computations()) { - pieces.push_back( - absl::StrFormat("computation %s order:", computation->name())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. - std::vector instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back(absl::StrFormat(" %s", instruction->name())); - } - } - return absl::StrJoin(pieces, "\n"); -} - -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence) { - for (auto computation_pair : module_sequence) { - const HloComputation* computation = computation_pair.first; - const std::vector& computation_sequence = - computation_pair.second; - out << "Computation " << computation->name() << ":\n"; - for (auto* instruction : computation_sequence) { - out << " " << instruction->name() << "\n"; - } - } - return out; + return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 985f3fa64d..b21071c4b2 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -183,17 +184,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - // TODO(dimvar): HloModuleSequence is not a good name because it sounds like - // a sequence of modules, instead of a map of schedules for all computations - // in a module. We should change it at some point. - // - // A sequence of instructions for each computation in the module. - using HloModuleSequence = - tensorflow::gtl::FlatMap>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); + SequentialHloOrdering(const HloSchedule& schedule); + SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. @@ -203,10 +195,12 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: + void Initialize(); + bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloModuleSequence module_sequence_; + const HloSchedule schedule_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -217,10 +211,6 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap order_position_; }; -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 126d3a2d9c..6b6005e7a5 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,11 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -376,5 +378,104 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } +TEST_F(HloOrderingTest, + ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { + // Tests that values live out of the module should interfere with values + // defined after the root instruction. That is: + // + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* entry = + module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); + + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param, root, dead}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + +TEST_F(HloOrderingTest, + ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { + // Tests that values live out of a computation should interfere with values + // defined after the root instruction of the computation. That is: + // + // subcomputation: + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // entry computation: + // %c = constant(42.0) + // ROOT %call = call({%c}), subcomputation + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto subbuilder = HloComputation::Builder(TestName() + ".sub"); + HloInstruction* param = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = subbuilder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = subbuilder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* subcomputation = module->AddEmbeddedComputation( + subbuilder.Build(/*root_instruction=*/root)); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(subcomputation, {param, root, dead}); + schedule.set_sequence(entry, {c, call}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c9629926ea..0a0a6a323e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -962,8 +962,7 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( } StatusOr HloRematerialization::RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, + HloComputation* computation, HloSchedule* schedule, int64 memory_limit_bytes) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); @@ -971,7 +970,8 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list(sequence->at(computation)); + InstructionList instruction_list( + schedule->sequence(computation).instructions()); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1145,7 +1145,7 @@ StatusOr HloRematerialization::RematerializeComputation( 0, memory_limit_bytes - memory_tracker.memory_usage()); TF_ASSIGN_OR_RETURN( bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, + RematerializeComputation(called_computation, schedule, subcomputation_memory_limit_bytes)); changed |= subcomputation_changed; } @@ -1179,12 +1179,12 @@ StatusOr HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. - auto& dst = sequence->at(computation); - dst.clear(); + HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation); + sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - dst.push_back(instruction); + sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1194,20 +1194,21 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run( - HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The sequence is constructed entirely by this method. - TF_RET_CHECK(sequence->empty()); +StatusOr HloRematerialization::Run(HloModule* module, + HloSchedule* schedule, + int64 memory_limit_bytes, + RematerializationSizes* sizes, + CopyInsertion* copy_insertion) { + // The schedule is constructed entirely by this method. + TF_RET_CHECK(schedule->empty()); VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, + // Create initial schedule of HLO instructions. + TF_ASSIGN_OR_RETURN(*schedule, + ScheduleModule(*module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, @@ -1217,16 +1218,7 @@ StatusOr HloRematerialization::Run( // ordering from the HLO schedule allows for more copies to be eliminated. // TODO(b/80249101): Instead of a separate copy elision pass, use the // ordering from the HLO schedule directly for copy insertion. - - // First create a copy of the schedule which contains HloInstruction unique - // ids instead of HloInstruction*. This is necessary for updating the - // schedule below. - // TODO(b/113175018): Remove this when the HLO schedule is self-contained - // and can update itself. - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(*sequence); - - SequentialHloOrdering ordering(module, *sequence); + SequentialHloOrdering ordering(*schedule); TF_RETURN_IF_ERROR( copy_insertion->RemoveUnnecessaryCopies(ordering, module)); @@ -1241,10 +1233,10 @@ StatusOr HloRematerialization::Run( // The passes above can add and remove copies, update the schedule to // account for these transformations. Newly added instructions will be // placed ASAP in the schedule. - TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); + TF_RETURN_IF_ERROR(schedule->Update()); TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(module, *sequence), module)); + SequentialHloOrdering(*schedule), module)); } TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); @@ -1271,12 +1263,13 @@ StatusOr HloRematerialization::Run( // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, sequence](const CallGraphNode& node) -> Status { + [this, schedule](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - sequence->at(node.computation()))); + ComputePeakMemory( + node.computation(), + schedule->sequence(node.computation()).instructions())); } return Status::OK(); }, @@ -1295,7 +1288,7 @@ StatusOr HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), sequence, + module->entry_computation(), schedule, adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an @@ -1305,30 +1298,7 @@ StatusOr HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. - for (const auto* computation : module->MakeNonfusionComputations()) { - if (sequence->at(computation).size() != computation->instruction_count()) { - // A size mismatch between the computation instruction count and the size - // of the ordering of instructions can only be caused by DCE. Rebuild the - // order by removing the deleted instructions from the order. - tensorflow::gtl::FlatSet instruction_set; - for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction); - } - // Move the old order into a temporary vector, then build new order - // inplace. - std::vector& order = sequence->at(computation); - std::vector old_order; - using std::swap; - swap(order, old_order); - std::copy_if(old_order.begin(), old_order.end(), - std::back_inserter(order), - [&instruction_set](const HloInstruction* instruction) { - return ContainsKey(instruction_set, instruction); - }); - TF_RET_CHECK(sequence->at(computation).size() == - computation->instruction_count()); - } - } + TF_RETURN_IF_ERROR(schedule->Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1366,11 +1336,10 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, + MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule, RematerializationSizes* sizes, CopyInsertion* copy_insertion) { HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, + return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes, copy_insertion); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ec004350a..fa0414b472 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -21,6 +21,7 @@ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -50,7 +51,7 @@ class HloRematerialization { // // hlo_module: HLO module to rematerialize instructions in. // - // sequence: Should point to an empty HloModuleSequence. Upon return + // schedule: Should point to an empty HloSchedule. Upon return // contains the HLO instruction order which was used for // rematerialization. This is the order in which HLO instructions should // be emitted to minimize memory use. @@ -75,8 +76,8 @@ class HloRematerialization { static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); + HloSchedule* schedule, RematerializationSizes* sizes, + CopyInsertion* copy_insertion = nullptr); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -87,10 +88,9 @@ class HloRematerialization { // Runs rematerialization on the given module. Returns whether the module was // changed. memory_limit is the target maximum peak memory usage by the - // module. sequence should be an empty HloModuleSequence. Upon return sequence + // module. schedule should be an empty HloSchedule. Upon return sequence // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr Run(HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence, + StatusOr Run(HloModule* module, HloSchedule* schedule, int64 memory_limit, RematerializationSizes* sizes, CopyInsertion* copy_insertion); @@ -98,10 +98,9 @@ class HloRematerialization { // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. - StatusOr RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 computation_memory_limit); + StatusOr RematerializeComputation(HloComputation* computation, + HloSchedule* schedule, + int64 memory_limit_bytes); // Computes and returns the peak memory used by the given computation. The // peak memory is the maximum total size of all live HLO instruction values at diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index ac8c97d380..83cb113bfb 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -141,13 +141,13 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - StatusOr RunHloRematerialization( - int64 memory_limit_bytes, HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence) { + StatusOr RunHloRematerialization(int64 memory_limit_bytes, + HloModule* module, + HloSchedule* schedule) { TF_EXPECT_OK(verifier().Run(module).status()); return HloRematerialization::RematerializeAndSchedule( ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence, /*sizes=*/nullptr); + schedule, /*sizes=*/nullptr); } // Various shapes used in the canned computations. @@ -170,12 +170,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/14 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,9 +187,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], + EXPECT_EQ(schedule.sequence(computation) + .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], + EXPECT_EQ(schedule.sequence(computation) + .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -203,10 +205,10 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/20 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -242,10 +244,10 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/17 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -276,10 +278,10 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/15 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -316,10 +318,10 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/13 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -382,14 +384,14 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,13 +478,13 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -571,13 +573,13 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc new file mode 100644 index 0000000000..a65b33bf40 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -0,0 +1,291 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { + +void HloSchedule::set_sequence( + const HloComputation* computation, + absl::Span sequence) { + set_sequence(computation, HloInstructionSequence(sequence)); +} + +void HloSchedule::set_sequence(const HloComputation* computation, + HloInstructionSequence sequence) { + CHECK(computation->parent() == module_); + sequences_[computation->unique_id()] = std::move(sequence); +} + +HloInstructionSequence& HloSchedule::GetOrCreateSequence( + const HloComputation* computation) { + auto it = sequences_.find(computation->unique_id()); + if (it == sequences_.end()) { + // No sequence found for computation. Create and return an empty one. + CHECK(computation->parent() == module_); + return sequences_[computation->unique_id()]; + } else { + return it->second; + } +} + +const HloInstructionSequence& HloSchedule::sequence( + const HloComputation* computation) const { + return sequences_.at(computation->unique_id()); +} + +Status HloSchedule::UpdateComputationSchedule( + const HloComputation* computation) { + // Map from unique ID to HloInstruction pointer for instructions in the + // computation. + tensorflow::gtl::FlatMap id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); + } + + // Set of all HloInstructions in the schedule. + tensorflow::gtl::FlatSet ids_in_schedule; + for (int id : sequences_.at(computation->unique_id()).ids()) { + InsertOrDie(&ids_in_schedule, id); + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // computation, but not in schedule) which use X. If an instruction is not in + // the map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap + unscheduled_operand_count; + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + if (instruction->operands().empty()) { + worklist.push(instruction); + } else { + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + HloInstructionSequence new_sequence; + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + new_sequence.push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : sequences_.at(computation->unique_id()).ids()) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. Do not add + // it to the new schedule. + continue; + } + worklist.push(it->second); + schedule_worklist(); + } + + set_sequence(computation, std::move(new_sequence)); + return Status::OK(); +} + +Status HloSchedule::Update() { + // The schedule must contain a sequence for every non-fusion computation in + // the module, but can have sequences for computations which no longer exist + // (these are removed). + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() << " not in HloSchedule."; + } + if (sequences_.size() > nonfusion_computations.size()) { + // Schedule contains some computations which have been removed from the + // HloModule. Remove them from the schedule as well. + tensorflow::gtl::FlatSet nonfusion_computations_ids; + for (const HloComputation* computation : nonfusion_computations) { + nonfusion_computations_ids.insert(computation->unique_id()); + } + for (auto it = sequences_.begin(); it != sequences_.end();) { + if (nonfusion_computations_ids.count(it->first) == 0) { + it = sequences_.erase(it); + } else { + it++; + } + } + } + CHECK_EQ(sequences_.size(), nonfusion_computations.size()); + + for (const HloComputation* computation : nonfusion_computations) { + TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation)); + } + + TF_RETURN_IF_ERROR(Verify()); + return Status::OK(); +} + +Status HloSchedule::Verify() const { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(3, module_->ToString()); + XLA_VLOG_LINES(2, ToString()); + + // Verify schedule contains exactly the same set of non-fusion computations as + // module currently does. + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequences_.size()) + << "Schedule has " << sequences_.size() << " sequences, but module has " + << nonfusion_computations.size() << " non-fusion computations"; + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() + << " missing from HLO schedule."; + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. + for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap instruction_position; + int pos = 0; + for (const HloInstruction* instruction : + sequence(computation).instructions()) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + +namespace { + +// Returns the computation in the given module with the given unique ID. Returns +// nullptr if no such computation exists. +const HloComputation* IdToComputation(const HloModule* module, int64 id) { + for (const HloComputation* computation : module->computations()) { + if (computation->unique_id() == id) { + return computation; + } + } + return nullptr; +} + +} // namespace + +string HloSchedule::ToString() const { + std::vector pieces; + + pieces.push_back("HloSchedule"); + for (const auto& id_sequence : sequences_) { + const HloComputation* computation = + IdToComputation(module_, id_sequence.first); + if (computation == nullptr) { + // The computation is not in the module and may have been deleted so it is + // not safe to dereference any HLO pointers. Just use the HLO unique ids + // stored in this object. + pieces.push_back( + absl::StrFormat("computation with id %d (no longer in HLO module):", + id_sequence.first)); + for (int id : id_sequence.second.ids()) { + pieces.push_back(absl::StrCat(" ", id)); + } + } else { + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); + for (const HloInstruction* instruction : + id_sequence.second.instructions()) { + pieces.push_back(absl::StrCat(" ", instruction->name())); + } + } + } + return absl::StrJoin(pieces, "\n"); +} + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { + out << schedule.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h new file mode 100644 index 0000000000..21c6988638 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +// Class representing a sequence of HLO instructions such as the sequential +// execution order of an HLO computation. +class HloInstructionSequence { + public: + HloInstructionSequence() = default; + HloInstructionSequence(absl::Span instructions) { + for (const HloInstruction* instruction : instructions) { + push_back(instruction); + } + } + + // Adds the instruction to the end of the sequence. + void push_back(const HloInstruction* instruction) { + instruction_sequence_.push_back(instruction); + id_sequence_.push_back(instruction->unique_id()); + } + + // Clears the sequence of all instructions. + void clear() { + instruction_sequence_.clear(); + id_sequence_.clear(); + } + + int64 size() const { return instruction_sequence_.size(); } + + // Returns the sequence of HLO instructions. + const std::vector& instructions() const { + return instruction_sequence_; + } + + // Returns the unique IDs of the instructions in the sequence (in order). + const std::vector& ids() const { return id_sequence_; } + + private: + // The sequence as HloInstructions. + std::vector instruction_sequence_; + + // The sequence of HLO instructions, represented by their unique IDs. The + // sequence is stored as both HloInstructions and unique IDs because the + // sequence may be referenced after transformations to the HLO graph and HLO + // pointers can be invalidated or recycled in this process (see + // HloSchedule::Update). + std::vector id_sequence_; +}; + +// A class representing a sequential schedule of instructions for an HLO +// module. A complete HLO schedule contains an instruction sequence for every +// non-fusion computation in the HLO module. +class HloSchedule { + public: + HloSchedule(const HloModule* module) : module_(module) {} + + // Returns a reference to the sequence for the given computation. + const HloInstructionSequence& sequence( + const HloComputation* computation) const; + + // Returns the sequence for the given computation. An empty sequence is + // created if none exists for the computation. + HloInstructionSequence& GetOrCreateSequence( + const HloComputation* computation); + + // Sets the sequence for the given computation to the given sequence. + void set_sequence(const HloComputation* computation, + absl::Span sequence); + void set_sequence(const HloComputation* computation, + HloInstructionSequence sequence); + + // Returns a map from HloComputation unique ID to instruction sequence. The + // map contains all sequences in the schedule. + const tensorflow::gtl::FlatMap& sequences() + const { + return sequences_; + } + + // Returns true if the schedule has a sequence for the given computation. + bool is_computation_scheduled(const HloComputation* computation) const { + return sequences_.count(computation->unique_id()) == 1; + } + + // Updates the schedule such that it is (again) a valid schedule for the + // module. This is used to update a schedule after the HLO module has been + // transformed in some way. In general, the only transformations to the module + // for which a schedule can be updated is the addition or removal of + // instructions and removal of computations. Updating the schedule after new + // dependencies between existing instructions in the module is not supported + // and may result in an error status returned. + // + // Instructions in the module which also exist in the given schedule will + // remain in the same order in the updated schedule. Instructions which exist + // in the module but not in the given schedule will be placed as early as + // possible in the updated schedule. + Status Update(); + + // Verifies that the given schedule is valid for the given module. + // Specifically, the schedule contains exactly the instructions in the + // non-fusion computations in the module and every dependency in the module is + // satisfied in the schedule. + Status Verify() const; + + string ToString() const; + + bool empty() const { return sequences_.empty(); } + + const HloModule* module() const { return module_; } + + private: + // Updates the instruction sequence for the given computation. + Status UpdateComputationSchedule(const HloComputation* computation); + + const HloModule* module_; + + // A map from computation unique ID to instruction sequence. Unique IDs are + // used rather than HloComputation pointers because HLO pointers are not + // unique across HLO transformations because pointers may be recycled. + tensorflow::gtl::FlatMap sequences_; +}; + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc new file mode 100644 index 0000000000..eb52582bb5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloScheduleTest : public HloTestBase {}; + +TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. + const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + const std::vector& entry_schedule = + schedule.sequence(module->entry_computation()).instructions(); + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(entry_schedule, + schedule.sequence(module->entry_computation()).instructions()); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. + const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo); + }; + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 4); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 3); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 2); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. + cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(body).size(), 7); + EXPECT_EQ(schedule.sequence(cond).size(), 4); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(body).size(), 1); + EXPECT_EQ(schedule.sequence(cond).size(), 5); +} + +TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) { + // Remove computations from a module and verify the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + HloInstruction* xla_while = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloInstruction* init = xla_while->mutable_operand(0); + + // Replace the while with its init value. The conditional and body + // computations should then be dead. + TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init)); + + // DCE the dead code in the body. + HloDCE dce; + ASSERT_EQ(module->computation_count(), 3); + TF_ASSERT_OK(dce.Run(module.get()).status()); + ASSERT_EQ(module->computation_count(), 1); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 0fc3b268c0..9bfb0af96c 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -70,7 +70,7 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. - static StatusOr> Run( + static StatusOr Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -229,8 +229,8 @@ class ListScheduler { return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; } - std::vector CreateSchedule() { - std::vector schedule; + HloInstructionSequence CreateSchedule() { + HloInstructionSequence schedule; // Populate the ready list with instructions which have no operands or // control predecessors. @@ -374,7 +374,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> ScheduleComputationHelper( +StatusOr ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -392,7 +392,7 @@ StatusOr> ScheduleComputationHelper( } // namespace -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -443,7 +443,7 @@ StatusOr> DFSMemoryScheduler( // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a // tiebreaker by name for determinism. - std::vector sequence; + HloInstructionSequence sequence; FunctionVisitor visitor([&sequence](HloInstruction* hlo) { sequence.push_back(hlo); return Status::OK(); @@ -463,7 +463,7 @@ StatusOr> DFSMemoryScheduler( return sequence; } // namespace xla -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -473,18 +473,16 @@ StatusOr> ListMemoryScheduler( memory_by_computation); } -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - const auto& post_order = computation.MakeInstructionPostOrder(); - return std::vector{post_order.begin(), - post_order.end()}; + return HloInstructionSequence(computation.MakeInstructionPostOrder()); } -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -499,7 +497,7 @@ StatusOr> DefaultMemoryScheduler( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. TF_ASSIGN_OR_RETURN( - std::vector list_sequence, + HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, @@ -508,7 +506,7 @@ StatusOr> DefaultMemoryScheduler( size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, @@ -518,7 +516,7 @@ StatusOr> DefaultMemoryScheduler( VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( - std::vector post_order_sequence, + HloInstructionSequence post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, @@ -545,32 +543,35 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); tensorflow::gtl::FlatMap memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( - *computation, one_computation_sequence, *points_to_analysis, + *computation, computation_sequence, *points_to_analysis, size_function, &memory_by_computation) .ValueOrDie(); - sequence[computation] = std::move(one_computation_sequence); + schedule.set_sequence(computation, std::move(computation_sequence)); } } - VLOG(1) << "Module schedule:\n" << sequence; - return sequence; + VLOG(1) << "Module schedule:\n" << schedule; + + TF_RETURN_IF_ERROR(schedule.Verify()); + + return std::move(schedule); } -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); @@ -581,187 +582,4 @@ StatusOr> ScheduleOneComputation( size_function, nullptr, empty_map); } -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { - tensorflow::gtl::FlatMap> id_sequence; - for (const auto& computation_sequence : sequence) { - for (const HloInstruction* instruction : computation_sequence.second) { - id_sequence[computation_sequence.first].push_back( - instruction->unique_id()); - } - } - return id_sequence; -} - -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence) { - // Map from unique ID to HloInstruction pointer for instructions in the - // module. - tensorflow::gtl::FlatMap id_to_instruction; - // Set of all HloInstructions in the schedule. - tensorflow::gtl::FlatSet ids_in_schedule; - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - for (const HloComputation* computation : nonfusion_computations) { - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK( - id_to_instruction.insert({instruction->unique_id(), instruction}) - .second); - } - for (int id : id_sequence.at(computation)) { - ids_in_schedule.insert(id); - } - } - - // Map from HloInstruction X to newly added instructions (instruction is in - // module, but not in schedule) which use X. If an instruction is not in the - // map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap> - new_instruction_uses; - - // For each newly added instruction, this is the count of the instruction's - // operands that have not yet been scheduled. When this value reaches zero, - // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap - unscheduled_operand_count; - // For each computation, this is the set of newly added instructions which - // have no operands. These must be handled specially and are added to the - // beginning of the schedule. - tensorflow::gtl::FlatMap> - new_zero_operand_instructions; - for (const HloComputation* computation : nonfusion_computations) { - new_zero_operand_instructions[computation] = {}; - for (const HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { - // This is a newly added instruction which is not in the schedule. - for (const HloInstruction* operand : instruction->operands()) { - new_instruction_uses[operand].push_back(instruction); - } - if (instruction->operands().empty()) { - new_zero_operand_instructions[computation].push_back(instruction); - } - unscheduled_operand_count[instruction] = instruction->operand_count(); - } - } - } - - // Update the schedule with the newly added instructions, and remove any - // instructions no longer in the graph. - for (const HloComputation* computation : nonfusion_computations) { - std::vector old_computation_sequence = - std::move(sequence->at(computation)); - sequence->at(computation).clear(); - - // Create a worklist of newly added instructions which are ready to be added - // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; - for (const HloInstruction* instruction : - new_zero_operand_instructions.at(computation)) { - worklist.push(instruction); - } - - // Lambda which schedules all instructions on the worklist. - auto schedule_worklist = [&]() { - while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); - worklist.pop(); - sequence->at(computation).push_back(instruction); - std::vector* new_users = - tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); - if (new_users != nullptr) { - // This just-scheduled instruction has users which are newly added to - // the module. Update the number of unscheduled operands and push the - // newly added instruction to the worklist if it is ready to - // schedule. - for (const HloInstruction* new_user : *new_users) { - unscheduled_operand_count.at(new_user)--; - CHECK_GE(unscheduled_operand_count.at(new_user), 0); - if (unscheduled_operand_count.at(new_user) == 0) { - worklist.push(new_user); - } - } - } - } - }; - - schedule_worklist(); - for (int id : id_sequence.at(computation)) { - auto it = id_to_instruction.find(id); - if (it == id_to_instruction.end()) { - // This instruction in the schedule is no longer in the module. - continue; - } - const HloInstruction* instruction = it->second; - worklist.push(instruction); - schedule_worklist(); - } - } - - TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); - return Status::OK(); -} - -Status VerifySchedule( - const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence) { - VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(2, module.ToString()); - VLOG(2) << sequence; - - // Verify the set of computations in the sequence is exactly the set of - // computations in the module. - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); - tensorflow::gtl::FlatSet computations_in_module( - module.computations().begin(), module.computations().end()); - for (const auto& computation_sequence : sequence) { - TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); - } - - // For each computation verify the set of instructions is the same and that - // each dependency and control edge is honored. - for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap instruction_position; - int pos = 0; - for (const HloInstruction* instruction : sequence.at(computation)) { - TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) - << "Instruction " << instruction->name() - << " appears more than once in the schedule"; - pos++; - } - - TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) - << "Instruction " << instruction->name() << " is not in schedule"; - } - - for (const HloInstruction* instruction : computation->instructions()) { - for (const HloInstruction* operand : instruction->operands()) { - TF_RET_CHECK(instruction_position.at(operand) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its operand " << operand->name(); - } - - for (const HloInstruction* pred : instruction->control_predecessors()) { - TF_RET_CHECK(instruction_position.at(pred) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its control predecessor " - << pred->name(); - } - } - } - - return Status::OK(); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index d06b8d9a5c..54e32340ba 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -32,14 +33,14 @@ namespace xla { // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. -typedef std::function>( +typedef std::function( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -47,7 +48,7 @@ StatusOr> ListMemoryScheduler( memory_by_computation); // DFS-order scheduler -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -55,7 +56,7 @@ StatusOr> DFSMemoryScheduler( memory_by_computation); // Naive Post Order scheduler -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -65,63 +66,26 @@ StatusOr> PostOrderMemoryScheduler( // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation); -// Returns an HloModuleSequence which seeks to minimize the memory required for +// Returns an HloSchedule which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); -// Transforms the given schedule such that it is (again) a valid schedule for -// the module. This is used to update a schedule after the HLO module has been -// transformed in some way. In general, the only transformations to the module -// for which a schedule can be updated is the addition or removal of -// instructions to/from the module. Updating the schedule after new dependencies -// between existing instructions in the module is not supported and may result -// in an error status returned. -// -// Instructions in the module which also exist in the given schedule will remain -// in the same order in the updated schedule. Instructions which exist in the -// module but not in the given schedule will be placed as early as possible in -// the updated schedule. -// -// 'id_sequence' is a mirror of the given schedule 'sequence' but with -// HloInstruction ids rather than HloInstruction pointers. This should be -// constructed using ComputeIdSchedule below after the schedule is constructed -// but before the HLO module is transformed. -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence); - -// Constructs a copy of the given schedule but with HloInstruction unique ids -// rather than HloInstruction pointers. This is necessary for updating a -// schedule as HloInstruction points in the schedule may become invalid if -// instructions are removed from the module. Used by UpdateSchedule above.. -// TODO(b/113175018): Remove this function when HLO schedule is its own class. -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); - -// Verifies that the given schedule is valid for the given module. Specifically, -// the schedule contains exactly the instructions in the module and every -// dependency in the module is satisfied in the schedule. -Status VerifySchedule(const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index d49d09d459..6afe51997e 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -67,19 +68,20 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + EXPECT_EQ(param, sequence.front()); + EXPECT_EQ(sub, sequence.back()); - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); } @@ -108,28 +110,26 @@ ENTRY root { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); std::unordered_map instructions_by_name; - for (const HloInstruction* instruction : - sequence.at(module->entry_computation())) { + for (const HloInstruction* instruction : sequence) { instructions_by_name[instruction->name()] = instruction; } // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), - sequence.at(module->entry_computation()).front()); - EXPECT_EQ(instructions_by_name.at("result"), - sequence.at(module->entry_computation()).back()); + EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); + EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); // Instructions "d" and "e" will both be schedulable at the same time, but // instruction "d" allows us to free the buffer of "p1", so the list scheduler // should prefer it. - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), instructions_by_name.at("e"))); } @@ -220,13 +220,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(entry_computation).size()); + SequentialHloOrdering ordering(schedule); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been // better to schedule it first, instead of during the busy time. @@ -243,13 +243,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. The output buffer is aliased, // so we don't double count. EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } @@ -281,19 +281,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // tuple allocates the tuple buffer and doesn't free anything. // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. // abs_abs2 should be scheduled before tuple by List. @@ -332,18 +331,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), 2); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // fusion allocates memory for the tuple elements and doesn't free anything, // so it's more expensive than exp. EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); @@ -391,12 +390,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); tensorflow::gtl::FlatMap memory_by_computation; memory_by_computation[cond_computation] = 17; @@ -406,262 +405,16 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. Cond is the largest one. // The output buffer of the while is aliased. EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } -TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { - // Updating the schedule of an unchanged HLO module should not affect the - // schedule at all. - const string module_str = R"( -HloModule UpdateScheduleUnchanged - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - std::vector entry_schedule = sequence.begin()->second; - - EXPECT_EQ(entry_schedule.size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(entry_schedule, sequence.begin()->second); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { - // Add some additional instructions to a module and verify the schedule can be - // updated. - const string module_str = R"( -HloModule UpdateScheduleWithNewInstructions - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - HloComputation* entry = module->entry_computation(); - const Shape shape = entry->root_instruction()->shape(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, constant, entry->root_instruction())); - entry->set_root_instruction(sub); - - auto in_schedule = [&](const HloInstruction* hlo) { - return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), - hlo) != sequence.at(entry).end(); - }; - - EXPECT_EQ(sequence.at(entry).size(), 6); - EXPECT_FALSE(in_schedule(constant)); - EXPECT_FALSE(in_schedule(sub)); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 8); - EXPECT_TRUE(in_schedule(constant)); - EXPECT_TRUE(in_schedule(sub)); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { - // Add and delete some instructions from a module and verify that the schedule - // can be updated successfully. - const string module_str = R"( -HloModule UpdateScheduleWithAddedAndDeletedInstruction - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Set the entry root to some expression containing just a parameter and a - // constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* new_root = entry->AddInstruction( - HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, - constant, entry->parameter_instruction(0))); - entry->set_root_instruction(new_root); - - // DCE should remove everything but the parameters and the newly added code. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 4); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { - // Completely replace a module with an entirely new set of instructions and - // verify that the schedule can be updated successfully. - const string module_str = R"( -HloModule UpdateScheduleWithCompletelyReplacedModule - -ENTRY main { - a = f32[] constant(42.0) - b = f32[] constant(123.0) - ROOT sum = f32[] add(a, b) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Replace the entry computation with the negation of a constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kNegate, constant)); - entry->set_root_instruction(new_root); - - // DCE the old instructions. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 3); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 2); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { - // Create changes to more than one computation in an HLO module and verify - // that the schedule can be updated. - const string module_str = R"( -HloModule UpdateScheduleWithMultipleComputations - -%Body (param.1: (s32[], token[])) -> (s32[], token[]) { - %param.1 = (s32[], token[]) parameter(0) - %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 - %constant.1 = s32[] constant(1) - %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) - %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %after-all = token[] after-all(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) -} - -%Cond (param: (s32[], token[])) -> pred[] { - %param = (s32[], token[]) parameter(0) - %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 - %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) -} - -ENTRY %WhileLoop () -> s32[] { - %zero = s32[] constant(0) - %init_token = token[] after-all() - %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) - %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body - ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - const HloInstruction* xla_while = - module->entry_computation()->root_instruction()->operand(0); - HloComputation* body = xla_while->while_body(); - HloComputation* cond = xla_while->while_condition(); - - // Negate the root of the cond. - cond->set_root_instruction(cond->AddInstruction( - HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kNot, cond->root_instruction()))); - - // Replace the body with a computation which just passes through its - // parameter. - body->set_root_instruction(body->parameter_instruction(0)); - - // DCE the dead code in the body. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(body).size(), 7); - EXPECT_EQ(sequence.at(cond).size(), 4); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(body).size(), 1); - EXPECT_EQ(sequence.at(cond).size(), 5); -} - } // namespace } // namespace xla -- GitLab From 5032036e1f2a7060848aed64bce94a1f882142d5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 10:48:07 -0700 Subject: [PATCH 125/540] Introduce auxiliary input and allow "cross-linking" in the bidirectional LSTM Op. This introduces a connection between forward and backward cells across subsequent layers when stacking bidirectional LSTM Ops on top of each other. In more detail: Previously, the Op had only one input that was fed into the layer in the following way: INPUT (INPUT_REVERSED) | | ----------------------- | FW_LSTM BW_LSTM | <----- bidi-LSTM cell (with one input / two outputs) ----------------------- | | FW_OUT BW_OUT Now, the Op can have an (optional) auxiliary input in the following way: AUX_INPUT (AUX_INPUT_REVERSED) | | INPUT | (INPUT_R'D.)| | | | | ------------------------- | \ / \ / | | FW_LSTM BW_LSTM | <----- bidi-LSMT cell (with 2 inputs / 2 outputs) ------------------------- | | FW_OUT BW_OUT When stacking these Ops, previously, only the following flow was allowed: Input / \ FW_LSTM1 BW_LSTM1 | | | | FW_LSTM2 BW_LSTM2 | | | | FW_LSTM3 BW_LSTM3 \ / Output With the introduction of an auxiliary input to the bidi-LSTM layer, the forward (FW_LSTMi) output of the ith layer is fed into as the input to the next layer (hence, inputs to both FW_LSTM{i+1} and BW_LSTM{i+1}) and the backward output is fed as the auxiliary inputs to both FW_LSTM{i+1} and BW_LSTM{i+1}). This way, the stacking can be changed to allow for the "cross-linking" between subsequent layer in the following way: Input / \ FW_LSTM1 BW_LSTM1 | \ / | | / \ | FW_LSTM2 BW_LSTM2 | \ / | | / \ | FW_LSTM3 BW_LSTM3 \ / Output PiperOrigin-RevId: 211659472 --- .../kernels/bidirectional_sequence_lstm.cc | 348 ++++++++++++++---- .../bidirectional_sequence_lstm_test.cc | 70 ++++ .../lite/kernels/internal/kernel_utils.cc | 39 +- .../lite/kernels/internal/kernel_utils.h | 17 +- 4 files changed, 368 insertions(+), 106 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index cde4f55a16..6b8ecdd5c3 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -104,6 +104,19 @@ constexpr int kBwInputActivationStateTensor = 37; // Cell state tensors of size {n_batch, n_cell} constexpr int kBwInputCellStateTensor = 38; +// Auxiliary input and weights when stacking. +constexpr int kAuxInputTensor = 39; // Optional +// Forward weights. +constexpr int kFwAuxInputToInputWeightsTensor = 40; // Optional +constexpr int kFwAuxInputToForgetWeightsTensor = 41; // Optional +constexpr int kFwAuxInputToCellWeightsTensor = 42; // Optional +constexpr int kFwAuxInputToOutputWeightsTensor = 43; // Optional +// Backward weights. +constexpr int kBwAuxInputToInputWeightsTensor = 44; // Optional +constexpr int kBwAuxInputToForgetWeightsTensor = 45; // Optional +constexpr int kBwAuxInputToCellWeightsTensor = 46; // Optional +constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional + // Output tensors. constexpr int kFwOutputTensor = 0; constexpr int kBwOutputTensor = 1; @@ -115,14 +128,15 @@ enum TemporaryTensor { kBwScratchBuffer = 1, // Quantized tensors needed for the hybrid kernel. kInputQuantized = 2, - kFwActivationStateQuantized = 3, - kBwActivationStateQuantized = 4, - kFwCellStateQuantized = 5, - kBwCellStateQuantized = 6, - kScalingFactors = 7, - kProductScalingFactors = 8, - kRecoveredCellWeights = 9, - kNumTemporaryTensors = 10 + kAuxInputQuantized = 3, // Quantized tensor needed for auxiliary input. + kFwActivationStateQuantized = 4, + kBwActivationStateQuantized = 5, + kFwCellStateQuantized = 6, + kBwCellStateQuantized = 7, + kScalingFactors = 8, + kProductScalingFactors = 9, + kRecoveredCellWeights = 10, + kNumTemporaryTensors = 11 }; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -335,7 +349,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int* scratch_tensor_index = reinterpret_cast(node->user_data); // Check we have all the inputs and outputs we need. - TF_LITE_ENSURE_EQ(context, node->inputs->size, 39); + TF_LITE_ENSURE_EQ(context, node->inputs->size, 48); TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); // Inferring batch size, number of outputs and sequence length and @@ -366,6 +380,48 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context, CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell)); + // Get (optional) auxiliary inputs and weights. + const TfLiteTensor* aux_input = + GetOptionalInputTensor(context, node, kAuxInputTensor); + const TfLiteTensor* fw_aux_input_to_input_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor); + const TfLiteTensor* fw_aux_input_to_forget_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor); + const TfLiteTensor* fw_aux_input_to_cell_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor); + const TfLiteTensor* fw_aux_input_to_output_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor); + const TfLiteTensor* bw_aux_input_to_input_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor); + const TfLiteTensor* bw_aux_input_to_forget_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor); + const TfLiteTensor* bw_aux_input_to_cell_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor); + const TfLiteTensor* bw_aux_input_to_output_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor); + + const bool aux_inputs_all_or_none = + ((aux_input != nullptr) && (fw_aux_input_to_cell_weights != nullptr) && + (fw_aux_input_to_forget_weights != nullptr) && + (fw_aux_input_to_output_weights != nullptr) && + (bw_aux_input_to_cell_weights != nullptr) && + (bw_aux_input_to_forget_weights != nullptr) && + (bw_aux_input_to_output_weights != nullptr)) || + ((fw_aux_input_to_cell_weights == nullptr) && + (fw_aux_input_to_forget_weights == nullptr) && + (fw_aux_input_to_output_weights == nullptr) && + (bw_aux_input_to_cell_weights == nullptr) && + (bw_aux_input_to_forget_weights == nullptr) && + (bw_aux_input_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, aux_inputs_all_or_none); + const bool has_aux_input = (aux_input != nullptr); + + if (has_aux_input) { + // Check that aux_input has the same dimensions (except last) as the input. + TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]); + TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]); + } + // Get the pointer to output, activation_state and cell_state buffer tensors. TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); TfLiteTensor* fw_activation_state = @@ -406,6 +462,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* fw_input_to_input_weights = GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor); + if (has_aux_input) { + TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0], + fw_input_to_input_weights->dims->data[0]); + } const bool fw_use_cifg = (fw_input_to_input_weights == nullptr); TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2); fw_scratch_buffer_size->data[0] = n_batch; @@ -470,6 +530,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bw_input_to_input_weights = GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor); + if (has_aux_input) { + TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0], + bw_input_to_input_weights->dims->data[0]); + } const bool bw_use_cifg = (bw_input_to_input_weights == nullptr); TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2); bw_scratch_buffer_size->data[0] = n_batch; @@ -483,8 +547,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, bw_scratch_buffer_size)); if (is_hybrid_op) { - // Allocate temporary tensors to store quantized values of input, - // output_state and cell_state tensors. + // Allocate temporary tensors to store quantized values of input, aux_input + // (if present), activation_state and cell_state tensors. node->temporaries->data[kInputQuantized] = *scratch_tensor_index + kInputQuantized; TfLiteTensor* input_quantized = @@ -497,6 +561,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_quantized_size)); } + if (has_aux_input) { + node->temporaries->data[kAuxInputQuantized] = + *scratch_tensor_index + kAuxInputQuantized; + TfLiteTensor* aux_input_quantized = + GetTemporary(context, node, kAuxInputQuantized); + aux_input_quantized->type = kTfLiteUInt8; + aux_input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) { + TfLiteIntArray* aux_input_quantized_size = + TfLiteIntArrayCopy(aux_input->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, aux_input_quantized, + aux_input_quantized_size)); + } + } + node->temporaries->data[kFwActivationStateQuantized] = *scratch_tensor_index + kFwActivationStateQuantized; TfLiteTensor* fw_activation_state_quantized = @@ -617,7 +697,11 @@ TfLiteStatus EvalFloat( const TfLiteTensor* recurrent_to_output_weights, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* aux_input_to_input_weights, + const TfLiteTensor* aux_input_to_forget_weights, + const TfLiteTensor* aux_input_to_cell_weights, + const TfLiteTensor* aux_input_to_output_weights, const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, @@ -627,6 +711,7 @@ TfLiteStatus EvalFloat( const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; + const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0; // n_cell and n_output will be the same size when there is no projection. const int n_cell = input_to_output_weights->dims->data[0]; @@ -671,25 +756,41 @@ TfLiteStatus EvalFloat( const float* projection_bias_ptr = (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + float* aux_input_ptr = nullptr; + float* aux_input_to_input_weights_ptr = nullptr; + float* aux_input_to_forget_weights_ptr = nullptr; + float* aux_input_to_cell_weights_ptr = nullptr; + float* aux_input_to_output_weights_ptr = nullptr; + if (aux_input_size > 0) { + aux_input_ptr = aux_input->data.f; + aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f; + aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f; + aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f; + aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f; + } + // Loop through the sequence. if (forward_sequence) { for (int t = 0; t < max_time; t++) { const float* input_ptr = input->data.f + t * n_batch * n_input; float* output_ptr_time = output->data.f + t * n_batch * n_output; - kernel_utils::LstmStep( + kernel_utils::LstmStepWithAuxInput( input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f, input_to_cell_weights->data.f, - input_to_output_weights->data.f, recurrent_to_input_weights_ptr, - recurrent_to_forget_weights->data.f, + input_to_output_weights->data.f, aux_input_ptr, + aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, + aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f, recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, - params, n_batch, n_cell, n_input, n_output, activation_state->data.f, - cell_state->data.f, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_time); + params, n_batch, n_cell, n_input, aux_input_size, n_output, + activation_state->data.f, cell_state->data.f, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, + output_ptr_time); } } else { // Loop through the sequence backwards. @@ -697,19 +798,22 @@ TfLiteStatus EvalFloat( const float* input_ptr = input->data.f + t * n_batch * n_input; float* output_ptr_time = output->data.f + t * n_batch * n_output; - kernel_utils::LstmStep( + kernel_utils::LstmStepWithAuxInput( input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f, input_to_cell_weights->data.f, - input_to_output_weights->data.f, recurrent_to_input_weights_ptr, - recurrent_to_forget_weights->data.f, + input_to_output_weights->data.f, aux_input_ptr, + aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr, + aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr, + recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f, recurrent_to_output_weights->data.f, cell_to_input_weights_ptr, cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f, output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr, - params, n_batch, n_cell, n_input, n_output, activation_state->data.f, - cell_state->data.f, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_time); + params, n_batch, n_cell, n_input, aux_input_size, n_output, + activation_state->data.f, cell_state->data.f, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, + output_ptr_time); } } return kTfLiteOk; @@ -726,19 +830,25 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* recurrent_to_output_weights, const TfLiteTensor* cell_to_input_weights, const TfLiteTensor* cell_to_forget_weights, - const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input, + const TfLiteTensor* aux_input_to_input_weights, + const TfLiteTensor* aux_input_to_forget_weights, + const TfLiteTensor* aux_input_to_cell_weights, + const TfLiteTensor* aux_input_to_output_weights, const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, const TfLiteLSTMParams* params, bool forward_sequence, TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, - TfLiteTensor* input_quantized, TfLiteTensor* output_state_quantized, - TfLiteTensor* cell_state_quantized, TfLiteTensor* output_state, - TfLiteTensor* cell_state, TfLiteTensor* output) { + TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int max_time = input->dims->data[0]; const int n_batch = input->dims->data[1]; const int n_input = input->dims->data[2]; + const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0; // n_cell and n_output will be the same size when there is no projection. const int n_cell = input_to_output_weights->dims->data[0]; const int n_output = recurrent_to_output_weights->dims->data[1]; @@ -842,6 +952,10 @@ TfLiteStatus EvalHybrid( // Temporary storage for quantized values and scaling factors. int8_t* quantized_input_ptr = reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_aux_input_ptr = + (aux_input_quantized == nullptr) + ? nullptr + : reinterpret_cast(aux_input_quantized->data.uint8); int8_t* quantized_output_state_ptr = reinterpret_cast(output_state_quantized->data.uint8); int8_t* quantized_cell_state_ptr = @@ -850,31 +964,63 @@ TfLiteStatus EvalHybrid( float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + // Auxiliary input and weights. + float* aux_input_ptr = nullptr; + int8_t* aux_input_to_input_weights_ptr = nullptr; + int8_t* aux_input_to_forget_weights_ptr = nullptr; + int8_t* aux_input_to_cell_weights_ptr = nullptr; + int8_t* aux_input_to_output_weights_ptr = nullptr; + float aux_input_to_input_weights_scale = 0.0f; + float aux_input_to_forget_weights_scale = 0.0f; + float aux_input_to_cell_weights_scale = 0.0f; + float aux_input_to_output_weights_scale = 0.0f; + if (aux_input_size > 0) { + aux_input_ptr = aux_input->data.f; + aux_input_to_input_weights_ptr = + reinterpret_cast(aux_input_to_input_weights->data.uint8); + aux_input_to_forget_weights_ptr = + reinterpret_cast(aux_input_to_forget_weights->data.uint8); + aux_input_to_cell_weights_ptr = + reinterpret_cast(aux_input_to_cell_weights->data.uint8); + aux_input_to_output_weights_ptr = + reinterpret_cast(aux_input_to_output_weights->data.uint8); + aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale; + aux_input_to_forget_weights_scale = + aux_input_to_forget_weights->params.scale; + aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale; + aux_input_to_output_weights_scale = + aux_input_to_output_weights->params.scale; + } if (forward_sequence) { // Feed the sequence into the LSTM step-by-step. for (int t = 0; t < max_time; t++) { const float* input_ptr = input->data.f + t * n_batch * n_input; float* output_ptr = output->data.f + t * n_batch * n_output; - kernel_utils::LstmStep( + kernel_utils::LstmStepWithAuxInput( input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, input_to_forget_weights_ptr, input_to_forget_weights_scale, input_to_cell_weights_ptr, input_to_cell_weights_scale, input_to_output_weights_ptr, input_to_output_weights_scale, - recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, - recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, - recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, - recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, - cell_to_input_weights_ptr, cell_to_input_weights_scale, - cell_to_forget_weights_ptr, cell_to_forget_weights_scale, - cell_to_output_weights_ptr, cell_to_output_weights_scale, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, - projection_weights_scale, projection_bias_ptr, params, n_batch, - n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, scaling_factors_ptr, - prod_scaling_factors_ptr, recovered_cell_weights_ptr, - quantized_input_ptr, quantized_output_state_ptr, + aux_input_ptr, aux_input_to_input_weights_ptr, + aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr, + aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr, + aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr, + aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, cell_to_input_weights_ptr, + cell_to_input_weights_scale, cell_to_forget_weights_ptr, + cell_to_forget_weights_scale, cell_to_output_weights_ptr, + cell_to_output_weights_scale, input_gate_bias_ptr, + forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, + projection_weights_ptr, projection_weights_scale, projection_bias_ptr, + params, n_batch, n_cell, n_input, aux_input_size, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_aux_input_ptr, quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr); } @@ -884,25 +1030,30 @@ TfLiteStatus EvalHybrid( const float* input_ptr = input->data.f + t * n_batch * n_input; float* output_ptr = output->data.f + t * n_batch * n_output; - kernel_utils::LstmStep( + kernel_utils::LstmStepWithAuxInput( input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, input_to_forget_weights_ptr, input_to_forget_weights_scale, input_to_cell_weights_ptr, input_to_cell_weights_scale, input_to_output_weights_ptr, input_to_output_weights_scale, - recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, - recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, - recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, - recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, - cell_to_input_weights_ptr, cell_to_input_weights_scale, - cell_to_forget_weights_ptr, cell_to_forget_weights_scale, - cell_to_output_weights_ptr, cell_to_output_weights_scale, - input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, - output_gate_bias_ptr, projection_weights_ptr, - projection_weights_scale, projection_bias_ptr, params, n_batch, - n_cell, n_input, n_output, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, scaling_factors_ptr, - prod_scaling_factors_ptr, recovered_cell_weights_ptr, - quantized_input_ptr, quantized_output_state_ptr, + aux_input_ptr, aux_input_to_input_weights_ptr, + aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr, + aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr, + aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr, + aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr, + recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr, + recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr, + recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr, + recurrent_to_output_weights_scale, cell_to_input_weights_ptr, + cell_to_input_weights_scale, cell_to_forget_weights_ptr, + cell_to_forget_weights_scale, cell_to_output_weights_ptr, + cell_to_output_weights_scale, input_gate_bias_ptr, + forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, + projection_weights_ptr, projection_weights_scale, projection_bias_ptr, + params, n_batch, n_cell, n_input, aux_input_size, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_aux_input_ptr, quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr); } @@ -1004,17 +1155,39 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bw_projection_bias = GetOptionalInputTensor(context, node, kBwProjectionBiasTensor); + // State tensors. TfLiteTensor* bw_activation_state = GetVariableInput(context, node, kBwInputActivationStateTensor); TfLiteTensor* bw_cell_state = GetVariableInput(context, node, kBwInputCellStateTensor); TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + // Temporary tensors. TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, kFwScratchBuffer); TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, kBwScratchBuffer); + // (Optional) auxiliary inputs. + const TfLiteTensor* aux_input = + GetOptionalInputTensor(context, node, kAuxInputTensor); + const TfLiteTensor* fw_aux_input_to_input_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor); + const TfLiteTensor* fw_aux_input_to_forget_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor); + const TfLiteTensor* fw_aux_input_to_cell_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor); + const TfLiteTensor* fw_aux_input_to_output_weights = + GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor); + const TfLiteTensor* bw_aux_input_to_input_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor); + const TfLiteTensor* bw_aux_input_to_forget_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor); + const TfLiteTensor* bw_aux_input_to_cell_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor); + const TfLiteTensor* bw_aux_input_to_output_weights = + GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor); + switch (fw_input_to_output_weights->type) { case kTfLiteFloat32: { TfLiteStatus fw_pass_status = EvalFloat( @@ -1023,10 +1196,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, fw_cell_to_input_weights, fw_cell_to_forget_weights, - fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, - fw_cell_bias, fw_output_gate_bias, fw_projection_weights, - fw_projection_bias, params, /*forward_sequence=*/true, - fw_scratch_buffer, fw_activation_state, fw_cell_state, fw_output); + fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights, + fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, + fw_aux_input_to_output_weights, fw_input_gate_bias, + fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias, + fw_projection_weights, fw_projection_bias, params, + /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state, + fw_cell_state, fw_output); TF_LITE_ENSURE_OK(context, fw_pass_status); TfLiteStatus bw_pass_status = EvalFloat( @@ -1035,16 +1211,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, bw_cell_to_input_weights, bw_cell_to_forget_weights, - bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, - bw_cell_bias, bw_output_gate_bias, bw_projection_weights, - bw_projection_bias, params, /*forward_sequence=*/false, - bw_scratch_buffer, bw_activation_state, bw_cell_state, bw_output); + bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights, + bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights, + bw_aux_input_to_output_weights, bw_input_gate_bias, + bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias, + bw_projection_weights, bw_projection_bias, params, + /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state, + bw_cell_state, bw_output); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; } case kTfLiteUInt8: { TfLiteTensor* input_quantized = GetTemporary(context, node, kInputQuantized); + TfLiteTensor* aux_input_quantized = + GetTemporary(context, node, kAuxInputQuantized); TfLiteTensor* fw_activation_state_quantized = GetTemporary(context, node, kFwActivationStateQuantized); TfLiteTensor* bw_activation_state_quantized = @@ -1059,19 +1240,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetTemporary(context, node, kProductScalingFactors); TfLiteTensor* recovered_cell_weights = GetTemporary(context, node, kRecoveredCellWeights); + TfLiteStatus fw_pass_status = EvalHybrid( input, fw_input_to_input_weights, fw_input_to_forget_weights, fw_input_to_cell_weights, fw_input_to_output_weights, fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights, fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights, fw_cell_to_input_weights, fw_cell_to_forget_weights, - fw_cell_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, - fw_cell_bias, fw_output_gate_bias, fw_projection_weights, - fw_projection_bias, params, /*forward_sequence=*/true, - fw_scratch_buffer, scaling_factors, prod_scaling_factors, - recovered_cell_weights, input_quantized, - fw_activation_state_quantized, fw_cell_state_quantized, - fw_activation_state, fw_cell_state, fw_output); + fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights, + fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, + fw_aux_input_to_output_weights, fw_input_gate_bias, + fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias, + fw_projection_weights, fw_projection_bias, params, + /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors, + prod_scaling_factors, recovered_cell_weights, input_quantized, + aux_input_quantized, fw_activation_state_quantized, + fw_cell_state_quantized, fw_activation_state, fw_cell_state, + fw_output); TF_LITE_ENSURE_OK(context, fw_pass_status); TfLiteStatus bw_pass_status = EvalHybrid( @@ -1080,13 +1265,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights, bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights, bw_cell_to_input_weights, bw_cell_to_forget_weights, - bw_cell_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, - bw_cell_bias, bw_output_gate_bias, bw_projection_weights, - bw_projection_bias, params, /*forward_sequence=*/false, - bw_scratch_buffer, scaling_factors, prod_scaling_factors, - recovered_cell_weights, input_quantized, - bw_activation_state_quantized, bw_cell_state_quantized, - bw_activation_state, bw_cell_state, bw_output); + bw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights, + fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, + fw_aux_input_to_output_weights, bw_input_gate_bias, + bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias, + bw_projection_weights, bw_projection_bias, params, + /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors, + prod_scaling_factors, recovered_cell_weights, input_quantized, + aux_input_quantized, bw_activation_state_quantized, + bw_cell_state_quantized, bw_activation_state, bw_cell_state, + bw_output); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc index d058fab529..74ba8021c2 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -177,6 +177,16 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bw_output_ = AddOutput(TensorType_FLOAT32); + aux_input_ = AddNullInput(); + fw_aux_input_to_input_weights_ = AddNullInput(); + fw_aux_input_to_forget_weights_ = AddNullInput(); + fw_aux_input_to_cell_weights_ = AddNullInput(); + fw_aux_input_to_output_weights_ = AddNullInput(); + bw_aux_input_to_input_weights_ = AddNullInput(); + bw_aux_input_to_forget_weights_ = AddNullInput(); + bw_aux_input_to_cell_weights_ = AddNullInput(); + bw_aux_input_to_output_weights_ = AddNullInput(); + SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, BuiltinOptions_LSTMOptions, CreateLSTMOptions(builder_, ActivationFunctionType_TANH, @@ -340,6 +350,16 @@ class BidirectionalLSTMOpModel : public SingleOpModel { int fw_output_; int bw_output_; + int aux_input_; + int fw_aux_input_to_input_weights_; + int fw_aux_input_to_forget_weights_; + int fw_aux_input_to_cell_weights_; + int fw_aux_input_to_output_weights_; + int bw_aux_input_to_input_weights_; + int bw_aux_input_to_forget_weights_; + int bw_aux_input_to_cell_weights_; + int bw_aux_input_to_output_weights_; + int n_batch_; int n_input_; int n_fw_cell_; @@ -415,6 +435,16 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor + + {n_batch, sequence_length, 0}, // aux_input tensor + {n_cell, 0}, // aux_fw_input_to_input tensor + {n_cell, 0}, // aux_fw_input_to_forget tensor + {n_cell, 0}, // aux_fw_input_to_cell tensor + {n_cell, 0}, // aux_fw_input_to_output tensor + {n_cell, 0}, // aux_bw_input_to_input tensor + {n_cell, 0}, // aux_bw_input_to_forget tensor + {n_cell, 0}, // aux_bw_input_to_cell tensor + {n_cell, 0}, // aux_bw_input_to_output tensor }); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, @@ -562,6 +592,16 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) { {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor + + {n_batch, sequence_length, 0}, // aux_input tensor + {n_cell, 0}, // aux_fw_input_to_input tensor + {n_cell, 0}, // aux_fw_input_to_forget tensor + {n_cell, 0}, // aux_fw_input_to_cell tensor + {n_cell, 0}, // aux_fw_input_to_output tensor + {n_cell, 0}, // aux_bw_input_to_input tensor + {n_cell, 0}, // aux_bw_input_to_forget tensor + {n_cell, 0}, // aux_bw_input_to_cell tensor + {n_cell, 0}, // aux_bw_input_to_output tensor }); lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, @@ -709,6 +749,16 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor + + {n_batch, sequence_length, 0}, // aux_input tensor + {n_cell, 0}, // aux_fw_input_to_input tensor + {n_cell, 0}, // aux_fw_input_to_forget tensor + {n_cell, 0}, // aux_fw_input_to_cell tensor + {n_cell, 0}, // aux_fw_input_to_output tensor + {n_cell, 0}, // aux_bw_input_to_input tensor + {n_cell, 0}, // aux_bw_input_to_forget tensor + {n_cell, 0}, // aux_bw_input_to_cell tensor + {n_cell, 0}, // aux_bw_input_to_output tensor }); lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, @@ -848,6 +898,16 @@ TEST(LSTMOpTest, {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor + + {n_batch, sequence_length, 0}, // aux_input tensor + {n_cell, 0}, // aux_fw_input_to_input tensor + {n_cell, 0}, // aux_fw_input_to_forget tensor + {n_cell, 0}, // aux_fw_input_to_cell tensor + {n_cell, 0}, // aux_fw_input_to_output tensor + {n_cell, 0}, // aux_bw_input_to_input tensor + {n_cell, 0}, // aux_bw_input_to_forget tensor + {n_cell, 0}, // aux_bw_input_to_cell tensor + {n_cell, 0}, // aux_bw_input_to_output tensor }); lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, @@ -987,6 +1047,16 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {n_batch, n_output}, // activation_state tensor {n_batch, n_cell}, // cell_state tensor + + {n_batch, sequence_length, 0}, // aux_input tensor + {n_cell, 0}, // aux_fw_input_to_input tensor + {n_cell, 0}, // aux_fw_input_to_forget tensor + {n_cell, 0}, // aux_fw_input_to_cell tensor + {n_cell, 0}, // aux_fw_input_to_output tensor + {n_cell, 0}, // aux_bw_input_to_input tensor + {n_cell, 0}, // aux_bw_input_to_forget tensor + {n_cell, 0}, // aux_bw_input_to_cell tensor + {n_cell, 0}, // aux_bw_input_to_output tensor }); lstm.SetInputToInputWeights( diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 360b472c45..b9dd40ddf9 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -203,9 +203,9 @@ void LstmStep( cell_to_input_weights_ptr, cell_to_forget_weights_ptr, cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - output_state_ptr, cell_state_ptr, input_gate_scratch, forget_gate_scratch, - cell_scratch, output_gate_scratch, output_ptr_batch); + projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0, + n_output, output_state_ptr, cell_state_ptr, input_gate_scratch, + forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch); } void LstmStepWithAuxInput( @@ -227,8 +227,8 @@ void LstmStepWithAuxInput( const float* forget_gate_bias_ptr, const float* cell_bias_ptr, const float* output_gate_bias_ptr, const float* projection_weights_ptr, const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, + int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, + float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch) { // Since we have already checked that weights are all there or none, we can @@ -268,19 +268,20 @@ void LstmStepWithAuxInput( if (aux_input_ptr_batch != nullptr) { if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_input_weights_ptr, n_cell, n_input, aux_input_ptr_batch, - n_batch, input_gate_scratch, /*result_stride=*/1); + aux_input_to_input_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, input_gate_scratch, + /*result_stride=*/1); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_forget_weights_ptr, n_cell, n_input, aux_input_ptr_batch, - n_batch, forget_gate_scratch, /*result_stride=*/1); + aux_input_to_forget_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_cell_weights_ptr, n_cell, n_input, aux_input_ptr_batch, + aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch, n_batch, cell_scratch, /*result_stride=*/1); tensor_utils::MatrixBatchVectorMultiplyAccumulate( - aux_input_to_output_weights_ptr, n_cell, n_input, aux_input_ptr_batch, - n_batch, output_gate_scratch, /*result_stride=*/1); + aux_input_to_output_weights_ptr, n_cell, n_aux_input, + aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1); } // For each batch and cell: compute recurrent_weight * output_state. @@ -432,10 +433,11 @@ void LstmStep( cell_to_output_weights_ptr, cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, - projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, - input_gate_scratch, forget_gate_scratch, cell_scratch, - output_gate_scratch, scaling_factors, product_scaling_factors, - recovered_cell_weights, quantized_input_ptr_batch, + projection_bias_ptr, params, n_batch, n_cell, n_input, + /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch, + cell_scratch, output_gate_scratch, scaling_factors, + product_scaling_factors, recovered_cell_weights, + quantized_input_ptr_batch, /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr_batch); @@ -476,8 +478,9 @@ void LstmStep( const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, float projection_weights_scale, const float* projection_bias_ptr, const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_output, float* input_gate_scratch, float* forget_gate_scratch, - float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + int n_aux_input, int n_output, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, + float* output_gate_scratch, float* scaling_factors, float* product_scaling_factors, float* recovered_cell_weights, int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch, diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index 38436c1382..215ad04add 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -131,8 +131,8 @@ void LstmStepWithAuxInput( const float* forget_gate_bias_ptr, const float* cell_bias_ptr, const float* output_gate_bias_ptr, const float* projection_weights_ptr, const float* projection_bias_ptr, const TfLiteLSTMParams* params, - int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr, - float* cell_state_ptr, float* input_gate_scratch, + int n_batch, int n_cell, int n_input, int n_aux_input, int n_output, + float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch); @@ -252,12 +252,13 @@ void LstmStepWithAuxInput( const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, float projection_weights_scale, const float* projection_bias_ptr, const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, - int n_output, float* input_gate_scratch, float* forget_gate_scratch, - float* cell_scratch, float* output_gate_scratch, float* scaling_factors, - float* product_scaling_factors, float* recovered_cell_weights, - int8_t* quantized_input_ptr_batch, int8_t* quantized_aux_input_ptr_batch, - int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr, - float* output_state_ptr, float* cell_state_ptr, float* output_ptr_batch); + int n_aux_input, int n_output, float* input_gate_scratch, + float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, + float* scaling_factors, float* product_scaling_factors, + float* recovered_cell_weights, int8_t* quantized_input_ptr_batch, + int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch); } // namespace kernel_utils } // namespace tflite -- GitLab From 587808a8ad12fdb20270bb4fefbf85a48702383b Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Wed, 5 Sep 2018 10:50:18 -0700 Subject: [PATCH 126/540] test_util.py: Allow use_gpu to change between calls to self.cached_session() use_gpu does not affect the creation of the session, it only affects the context manager in which nodes are added to the graph, so it should not be included in the consistency check. PiperOrigin-RevId: 211659833 --- tensorflow/python/framework/test_util.py | 156 ++++++++---------- tensorflow/python/framework/test_util_test.py | 3 - 2 files changed, 66 insertions(+), 93 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 3b63e49a84..0925598e33 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1073,13 +1073,9 @@ class TensorFlowTestCase(googletest.TestCase): if context.executing_eagerly(): yield None else: - sess = self._create_session(graph, config, use_gpu, force_gpu) - with self._constrain_devices_and_set_default( - sess, use_gpu, force_gpu) as constrained_sess: - # We need to do this to make sure the session closes, otherwise, even - # if the user does with self.session():, it will not close the session. - with constrained_sess: - yield constrained_sess + with self._create_session(graph, config, force_gpu) as sess: + with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu): + yield sess @contextlib.contextmanager def cached_session(self, @@ -1127,10 +1123,11 @@ class TensorFlowTestCase(googletest.TestCase): if context.executing_eagerly(): yield None else: - with self._get_cached_session( - graph, config, use_gpu, force_gpu, - crash_if_inconsistent_args=True) as sess: - yield sess + sess = self._get_cached_session( + graph, config, force_gpu, crash_if_inconsistent_args=True) + with self._constrain_devices_and_set_default(sess, use_gpu, + force_gpu) as cached: + yield cached @contextlib.contextmanager def test_session(self, @@ -1146,10 +1143,11 @@ class TensorFlowTestCase(googletest.TestCase): yield None else: if graph is None: - with self._get_cached_session( - graph, config, use_gpu, force_gpu, - crash_if_inconsistent_args=False) as sess: - yield sess + sess = self._get_cached_session( + graph, config, force_gpu, crash_if_inconsistent_args=False) + with self._constrain_devices_and_set_default(sess, use_gpu, + force_gpu) as cached: + yield cached else: with self.session(graph, config, use_gpu, force_gpu) as sess: yield sess @@ -1835,91 +1833,69 @@ class TensorFlowTestCase(googletest.TestCase): with sess.graph.device("/cpu:0"): yield sess - def _create_session(self, graph, config, use_gpu, force_gpu): + def _create_session(self, graph, config, force_gpu): """See session() for details.""" - if context.executing_eagerly(): - return None - else: + def prepare_config(config): + """Returns a config for sessions. - def prepare_config(config): - """Returns a config for sessions. - - Args: - config: An optional config_pb2.ConfigProto to use to configure the - session. - Returns: - A config_pb2.ConfigProto object. - """ - if config is None: - config = config_pb2.ConfigProto() - config.allow_soft_placement = not force_gpu - config.gpu_options.per_process_gpu_memory_fraction = 0.3 - elif force_gpu and config.allow_soft_placement: - config = config_pb2.ConfigProto().CopyFrom(config) - config.allow_soft_placement = False - # Don't perform optimizations for tests so we don't inadvertently run - # gpu ops on cpu - config.graph_options.optimizer_options.opt_level = -1 - config.graph_options.rewrite_options.constant_folding = ( - rewriter_config_pb2.RewriterConfig.OFF) - config.graph_options.rewrite_options.arithmetic_optimization = ( - rewriter_config_pb2.RewriterConfig.OFF) - return config - - return ErrorLoggingSession(graph=graph, config=prepare_config(config)) + Args: + config: An optional config_pb2.ConfigProto to use to configure the + session. + + Returns: + A config_pb2.ConfigProto object. + """ + if config is None: + config = config_pb2.ConfigProto() + config.allow_soft_placement = not force_gpu + config.gpu_options.per_process_gpu_memory_fraction = 0.3 + elif force_gpu and config.allow_soft_placement: + config = config_pb2.ConfigProto().CopyFrom(config) + config.allow_soft_placement = False + # Don't perform optimizations for tests so we don't inadvertently run + # gpu ops on cpu + config.graph_options.optimizer_options.opt_level = -1 + config.graph_options.rewrite_options.constant_folding = ( + rewriter_config_pb2.RewriterConfig.OFF) + config.graph_options.rewrite_options.arithmetic_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + return config + + return ErrorLoggingSession(graph=graph, config=prepare_config(config)) - @contextlib.contextmanager def _get_cached_session(self, graph=None, config=None, - use_gpu=False, force_gpu=False, crash_if_inconsistent_args=True): """See cached_session() for documentation.""" - if context.executing_eagerly(): - yield None + if self._cached_session is None: + sess = self._create_session( + graph=graph, config=config, force_gpu=force_gpu) + self._cached_session = sess + self._cached_graph = graph + self._cached_config = config + self._cached_force_gpu = force_gpu + return sess else: - if self._cached_session is None: - sess = self._create_session( - graph=graph, config=config, use_gpu=use_gpu, force_gpu=force_gpu) - self._cached_session = sess - self._cached_graph = graph - self._cached_config = config - self._cached_use_gpu = use_gpu - self._cached_force_gpu = force_gpu - with self._constrain_devices_and_set_default( - sess, use_gpu, force_gpu) as constrained_sess: - yield constrained_sess - else: - if crash_if_inconsistent_args and self._cached_graph is not graph: - raise ValueError("The graph used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - if crash_if_inconsistent_args and self._cached_config is not config: - raise ValueError("The config used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - if crash_if_inconsistent_args and self._cached_use_gpu is not use_gpu: - raise ValueError( - "The use_gpu value used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - if crash_if_inconsistent_args and (self._cached_force_gpu is - not force_gpu): - raise ValueError( - "The force_gpu value used to get the cached session is " - "different than the one that was used to create the " - "session. Maybe create a new session with " - "self.session()") - # If you modify this logic, make sure to modify it in _create_session - # as well. - sess = self._cached_session - with self._constrain_devices_and_set_default( - sess, use_gpu, force_gpu) as constrained_sess: - yield constrained_sess + if crash_if_inconsistent_args and self._cached_graph is not graph: + raise ValueError("The graph used to get the cached session is " + "different than the one that was used to create the " + "session. Maybe create a new session with " + "self.session()") + if crash_if_inconsistent_args and self._cached_config is not config: + raise ValueError("The config used to get the cached session is " + "different than the one that was used to create the " + "session. Maybe create a new session with " + "self.session()") + if crash_if_inconsistent_args and (self._cached_force_gpu is + not force_gpu): + raise ValueError( + "The force_gpu value used to get the cached session is " + "different than the one that was used to create the " + "session. Maybe create a new session with " + "self.session()") + return self._cached_session @tf_export("test.create_local_cluster") diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index a0939f98b2..c4f8fa9108 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -70,9 +70,6 @@ class TestUtilTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): with self.cached_session(graph=ops.Graph()) as sess2: pass - with self.assertRaises(ValueError): - with self.cached_session(use_gpu=True) as sess2: - pass with self.assertRaises(ValueError): with self.cached_session(force_gpu=True) as sess2: pass -- GitLab From d27c60b1a09dab2a0b35a76d46305c713c0735a6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 10:59:15 -0700 Subject: [PATCH 127/540] libc++ fix: make comparison functors const PiperOrigin-RevId: 211661670 --- tensorflow/core/grappler/graph_analyzer/graph_analyzer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h index 26d38a4931..97626346c7 100644 --- a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h @@ -138,7 +138,7 @@ class GraphAnalyzer { // The entries are owned by collation_map_, so must be removed from // ordered_collation_ before removing them from collation_map_. struct ReverseLessByCount { - bool operator()(CollationEntry* left, CollationEntry* right) { + bool operator()(CollationEntry* left, CollationEntry* right) const { return left->count > right->count; // Reverse order. } }; -- GitLab From 6aa8abbb17c06fbaeb9cc4396e58b6cfc33d177f Mon Sep 17 00:00:00 2001 From: Dimitris Vardoulakis Date: Wed, 5 Sep 2018 11:03:06 -0700 Subject: [PATCH 128/540] [TF:XLA] Define DefaultPrecisionConfig in HloTestBase and delete multiple duplicate definitions. PiperOrigin-RevId: 211662523 --- .../compiler/xla/service/algebraic_simplifier_test.cc | 7 ------- .../xla/service/cpu/conv_canonicalization_test.cc | 7 ------- .../xla/service/gpu/cudnn_convolution_rewriter_test.cc | 7 ------- tensorflow/compiler/xla/service/heap_simulator_test.cc | 7 ------- tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 7 ------- tensorflow/compiler/xla/service/hlo_instruction_test.cc | 7 ------- tensorflow/compiler/xla/service/transpose_folding_test.cc | 7 ------- tensorflow/compiler/xla/tests/hlo_test_base.cc | 8 ++++++++ tensorflow/compiler/xla/tests/hlo_test_base.h | 2 ++ tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 7 ------- 10 files changed, 10 insertions(+), 56 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 019840b476..0db74bd038 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -1013,13 +1013,6 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { 1); } -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { auto builder = HloComputation::Builder(TestName()); HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 616c453750..05792795a1 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -56,13 +56,6 @@ class ConvCanonicalizationTest : public HloTestBase { static constexpr int kOutputFeatureCount = 64; }; -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 9b46bfc098..bda8ebe579 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -95,13 +95,6 @@ class CudnnConvolutionRewriterTest : public HloVerifiedTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 1d98c45567..00a25db467 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -359,13 +359,6 @@ TEST_F(HeapSimulatorTest, BufferReusedOnce) { (neg_buffer == output_buffer_1)); } -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - TEST_F(HeapSimulatorTest, MultiplyDot) { auto builder = HloComputation::Builder(TestName()); auto paramA = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index f586f253da..abd4bb1f73 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -622,13 +622,6 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result)); } -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - TEST_P(HloEvaluatorTest, DotRank2AndRank1) { HloComputation::Builder b(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index b4e302e832..9eab6eea80 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1122,13 +1122,6 @@ TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) { } } -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { // Fused expression: // diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index e486a00e53..79b5c09abb 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -215,13 +215,6 @@ ENTRY entry_computation { /*lhs_contracting_dim=*/1, /*rhs_contracting_dim=*/1)); } -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - // Test that a two dimension swap of the kernel gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { auto builder = HloComputation::Builder("entry_computation"); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index fc4c68246e..edab480091 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -120,6 +120,14 @@ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, return status_or; } +/* static */ +PrecisionConfigProto HloTestBase::DefaultPrecisionConfig(int operands) { + PrecisionConfigProto precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfigProto::DEFAULT); + return precision_config; +} + DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 4c88257bb2..89e72a045e 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -80,6 +80,8 @@ class HloTestBase : public ::testing::Test { static StatusOr RunHloPass(HloPassInterface* hlo_pass, HloModule* module); + static PrecisionConfigProto DefaultPrecisionConfig(int operands); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 53b5e933b6..c5e0b9b097 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -47,13 +47,6 @@ limitations under the License. namespace xla { namespace { -PrecisionConfigProto DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); - return precision_config; -} - class MultiOutputFusionTest : public HloTestBase { protected: MultiOutputFusionTest() { error_spec_ = ErrorSpec{0.0001, 1e-2}; } -- GitLab From a6c4916764392819f3692dc0f763472d22b8076f Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 5 Sep 2018 11:08:52 -0700 Subject: [PATCH 129/540] Allow gradients() calls from inside a tfe.defun wrt captured tensors. This modifies https://github.com/tensorflow/tensorflow/commit/834da2c3fddab1bbbce742db572cfe65dd320fcd to work with tfe.defun in addition to the legacy Defun implementation. PiperOrigin-RevId: 211663702 --- tensorflow/python/BUILD | 12 +++++++ tensorflow/python/client/session_test.py | 2 ++ tensorflow/python/eager/BUILD | 2 +- tensorflow/python/eager/function.py | 3 ++ tensorflow/python/ops/gradients.py | 2 +- tensorflow/python/ops/gradients_impl.py | 45 +++++++++++++++++------- tensorflow/python/ops/gradients_test.py | 31 ++++++++-------- 7 files changed, 69 insertions(+), 28 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5af6437c56..e6169e9e80 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2090,6 +2090,18 @@ py_library( srcs = [ "ops/custom_gradient.py", "ops/gradients.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":gradients_impl", + "//tensorflow/python/eager:function", + "//tensorflow/python/eager:tape", + ], +) + +py_library( + name = "gradients_impl", + srcs = [ "ops/gradients_impl.py", ], srcs_version = "PY2AND3", diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 052be68385..f87a96e547 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -49,6 +49,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_control_flow_ops +# Import gradients to resolve circular imports +from tensorflow.python.ops import gradients # pylint: disable=unused-import from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 6f48d38b58..85da1baaf0 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -241,7 +241,7 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", + "//tensorflow/python:gradients_impl", "//tensorflow/python:graph_to_function_def", "//tensorflow/python:util", "//tensorflow/python/eager:context", diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 6c87dccaf1..b57979b484 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -55,6 +55,9 @@ from tensorflow.python.util import tf_inspect # (function -> gradients_impl -> control_flow_ops -> cond_v2_impl). cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access +# This is to avoid a circular dependency with gradients_impl +gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access + def create_substitute_placeholder(value, name, dtype=None): """Creates a placeholder for `value` and propagates shape info to it.""" diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py index 9fa8e27d5c..1dc666e78b 100644 --- a/tensorflow/python/ops/gradients.py +++ b/tensorflow/python/ops/gradients.py @@ -19,10 +19,10 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import +from tensorflow.python.eager import function from tensorflow.python.eager.backprop import GradientTape from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.gradients_impl import AggregationMethod from tensorflow.python.ops.gradients_impl import gradients from tensorflow.python.ops.gradients_impl import hessians # pylint: enable=unused-import - diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index a68f680224..3268b38b86 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -31,7 +31,7 @@ from tensorflow.core.framework import attr_value_pb2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function +from tensorflow.python.framework import function as framework_function from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -58,6 +58,10 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export +# This is to avoid a circular dependency (eager.function depends on +# gradients_impl). This is set in eager/function.py. +_function = None + # This is to avoid a circular dependency with cond_v2_impl. cond_v2_impl._gradients_impl = sys.modules[__name__] # pylint: disable=protected-access @@ -121,7 +125,7 @@ def _MarkReachedOps(from_ops, reached_ops, func_graphs): Args: from_ops: list of Operations. reached_ops: set of Operations. - func_graphs: list of function._FuncGraphs. This method will traverse through + func_graphs: list of _function.FuncGraphs. This method will traverse through these functions if they capture from_ops or any reachable ops. """ queue = collections.deque() @@ -146,7 +150,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs, to_ops: list of Operations. from_ops: list of Operations. colocate_gradients_with_ops: Python bool. See docstring of gradients(). - func_graphs: list of function._FuncGraphs. This method will traverse through + func_graphs: list of _function.FuncGraphs. This method will traverse through these functions if they capture from_ops or any reachable ops. This is useful if to_ops occur in a function and from_ops are in an outer function or graph. @@ -441,6 +445,19 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs): % target_op.name) +def _IsFunction(graph): + return (isinstance(graph, _function.FuncGraph) or + isinstance(graph, framework_function._FuncGraph)) # pylint: disable=protected-access + + +def _Captures(func_graph): + if isinstance(func_graph, _function.FuncGraph): + return func_graph.captures + else: + assert isinstance(func_graph, framework_function._FuncGraph) # pylint: disable=protected-access + return func_graph._captured # pylint: disable=protected-access + + def _MaybeCaptured(t): """If t is a captured value placeholder, returns the original captured value. @@ -448,11 +465,11 @@ def _MaybeCaptured(t): t: Tensor Returns: - A tensor, potentially from a different Graph/function._FuncGraph. + A tensor, potentially from a different Graph/_function.FuncGraph. """ # pylint: disable=protected-access - if isinstance(t.op.graph, function._FuncGraph) and t.op.type == "Placeholder": - for input_t, placeholder_t in t.op.graph._captured.items(): + if _IsFunction(t.op.graph) and t.op.type == "Placeholder": + for input_t, placeholder_t in _Captures(t.op.graph).items(): if t == placeholder_t: return _MaybeCaptured(input_t) # pylint: enable=protected-access @@ -470,10 +487,10 @@ def _Inputs(op, xs): Returns: A list of tensors. The tensors may be from multiple - Graph/function._FuncGraphs if op is in a function._FuncGraph and has + Graph/_function.FuncGraphs if op is in a _function.FuncGraph and has captured inputs. """ - if isinstance(op.graph, function._FuncGraph): # pylint: disable=protected-access + if _IsFunction(op.graph): # pylint: disable=protected-access # If we're differentiating w.r.t. `t`, do not attempt to traverse through it # to a captured value. The algorithm needs to "see" `t` in this case, even # if it's a function input for a captured value, whereas usually we'd like @@ -489,7 +506,7 @@ def _Consumers(t, func_graphs): Args: t: Tensor - func_graphs: a list of function._FuncGraphs that may have captured t. + func_graphs: a list of _function.FuncGraphs that may have captured t. Returns: A list of tensors. The tensors will be from the current graph and/or @@ -497,7 +514,7 @@ def _Consumers(t, func_graphs): """ consumers = t.consumers() for func in func_graphs: - for input_t, placeholder in func._captured.items(): # pylint: disable=protected-access + for input_t, placeholder in _Captures(func).items(): if input_t == t: consumers.extend(_Consumers(placeholder, func_graphs)) return consumers @@ -616,9 +633,13 @@ def _GradientsHelper(ys, # ancestor graphs. This is necessary for correctly handling captured values. func_graphs = [] curr_graph = src_graph - while isinstance(curr_graph, function._FuncGraph): # pylint: disable=protected-access + while _IsFunction(curr_graph): func_graphs.append(curr_graph) - curr_graph = curr_graph._outer_graph # pylint: disable=protected-access + if isinstance(curr_graph, _function.FuncGraph): + curr_graph = curr_graph.outer_graph + else: + assert isinstance(curr_graph, framework_function._FuncGraph) # pylint: disable=protected-access + curr_graph = curr_graph._outer_graph # pylint: disable=protected-access ys = _AsList(ys) xs = _AsList(xs) diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index fa9910b351..3759d8a543 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -26,9 +26,10 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function +from tensorflow.python.framework import function as framework_function from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util @@ -369,8 +370,8 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): @classmethod def _GetFunc(cls, **kwargs): - return function.Defun(dtypes.float32, dtypes.float32, ** - kwargs)(cls.XSquarePlusB) + return framework_function.Defun(dtypes.float32, dtypes.float32, ** + kwargs)(cls.XSquarePlusB) def _GetFuncGradients(self, f, x_value, b_value): x = constant_op.constant(x_value, name="x") @@ -408,8 +409,9 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): def testFunctionGradientsWithGradFunc(self): g = ops.Graph() with g.as_default(): - grad_func = function.Defun(dtypes.float32, dtypes.float32, - dtypes.float32)(self.XSquarePlusBGradient) + grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, + dtypes.float32)( + self.XSquarePlusBGradient) f = self._GetFunc(grad_func=grad_func) # Get gradients (should add SymbolicGradient node for function, which # uses the grad_func above, which multiplies all gradients by 2). @@ -430,8 +432,9 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): def testFunctionGradientWithGradFuncAndRegistration(self): g = ops.Graph() with g.as_default(): - grad_func = function.Defun(dtypes.float32, dtypes.float32, - dtypes.float32)(self.XSquarePlusBGradient) + grad_func = framework_function.Defun(dtypes.float32, dtypes.float32, + dtypes.float32)( + self.XSquarePlusBGradient) with self.assertRaisesRegexp(ValueError, "Gradient defined twice"): f = self._GetFunc( grad_func=grad_func, python_grad_func=self._PythonGradient) @@ -441,7 +444,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(): x = constant_op.constant(1.0, name="x") - @function.Defun() + @function.defun() def Foo(): y = math_ops.multiply(x, 2.0, name="y") g = gradients_impl.gradients(y, x) @@ -456,7 +459,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): x = constant_op.constant(1.0, name="x") y = math_ops.multiply(x, 2.0, name="y") - @function.Defun() + @framework_function.Defun() def Foo(): g = gradients_impl.gradients(y, x) return g[0] @@ -469,7 +472,7 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(): var = resource_variable_ops.ResourceVariable(1.0, name="var") - @function.Defun() + @function.defun() def Foo(): y = math_ops.multiply(var, 2.0, name="y") g = gradients_impl.gradients(y, var) @@ -486,11 +489,11 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): x2 = constant_op.constant(2.0, name="x2") x3 = math_ops.multiply(x1, x2, name="x3") - @function.Defun() + @function.defun() def Outer(): outer1 = array_ops.identity(x1, name="outer1") - @function.Defun() + @function.defun() def Inner(): inner1 = array_ops.identity(outer1, name="inner1") inner2 = array_ops.identity(x2, name="inner2") @@ -511,11 +514,11 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default(): x = constant_op.constant(1.0, name="x") - @function.Defun() + @function.defun() def Outer(): y = math_ops.multiply(x, 2.0, name="y") - @function.Defun() + @function.defun() def Inner(): z = math_ops.multiply(y, 3.0, name="z") g = gradients_impl.gradients(z, y) -- GitLab From 5d60dd9eab07bd02553cf7542641a08b0e3667cb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 11:17:09 -0700 Subject: [PATCH 130/540] Internal change. PiperOrigin-RevId: 211665268 --- tensorflow/core/kernels/gather_nd_op_cpu_impl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h index 66ae7f0894..277ee2be02 100644 --- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -123,10 +123,10 @@ struct GatherNdSlice { // is considerably more efficient. #pragma omp parallel for for (Eigen::DenseIndex i = 0; i < batch_size; i++) { - const Eigen::array loc = i; + const Eigen::array loc{i}; gather_nd_generator(loc); } -#else +#else // INTEL_MKL Tscratch.device(d) = Tscratch.reshape(reshape_dims) .broadcast(broadcast_dims) .generate(gather_nd_generator) -- GitLab From d3a63ee12b1c8910cf71e87a81e59f998144ce36 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Wed, 5 Sep 2018 11:24:13 -0700 Subject: [PATCH 131/540] Internal Change. PiperOrigin-RevId: 211666438 --- tensorflow/contrib/__init__.py | 8 ++++++++ tensorflow/python/__init__.py | 7 +++++++ tensorflow/python/tools/component_api_helper.py | 5 +++-- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 5f477a79a3..9478e42b46 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -21,6 +21,14 @@ from __future__ import print_function import os +from tensorflow.python.tools import component_api_helper +component_api_helper.package_hook( + parent_package_str=( + "tensorflow.contrib"), + child_package_str=( + "tensorflow_estimator.contrib.estimator")) +del component_api_helper + # Add projects here, they will show up under tf.contrib. from tensorflow.contrib import autograph from tensorflow.contrib import batching diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index a2ab63bb48..4921ecc43c 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -48,6 +48,13 @@ import numpy as np from tensorflow.python import pywrap_tensorflow +from tensorflow.python.tools import component_api_helper +component_api_helper.package_hook( + parent_package_str='tensorflow.python', + child_package_str=( + 'tensorflow_estimator.python.estimator')) +del component_api_helper + # Protocol buffers from tensorflow.core.framework.graph_pb2 import * from tensorflow.core.framework.node_def_pb2 import * diff --git a/tensorflow/python/tools/component_api_helper.py b/tensorflow/python/tools/component_api_helper.py index 988ecc61f0..97f46719e5 100644 --- a/tensorflow/python/tools/component_api_helper.py +++ b/tensorflow/python/tools/component_api_helper.py @@ -65,9 +65,10 @@ def package_hook(parent_package_str, child_package_str, error_msg=None): Will allow the following import statement to work. >>> import parent.child """ - child_pkg_path = [os.path.join(os.path.dirname(child_pkg.__file__), "..")] + child_pkg_path = [os.path.abspath( + os.path.join(os.path.dirname(child_pkg.__file__), ".."))] try: - parent_pkg.__path__ += child_pkg_path + parent_pkg.__path__ = child_pkg_path + parent_pkg.__path__ except AttributeError: parent_pkg.__path__ = child_pkg_path -- GitLab From d6e95e5de2041110530ea7b1fe36b77c9469b1ff Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Wed, 5 Sep 2018 12:20:07 -0700 Subject: [PATCH 132/540] Make logging less verbose I want --vmodule=xla_compilation_cache=1 to print only the most essential things. PiperOrigin-RevId: 211676846 --- tensorflow/compiler/jit/xla_compilation_cache.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index ef6b0e67d3..dcb0b3240a 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -259,7 +259,7 @@ Status XlaCompilationCache::CompileImpl( const XlaCompiler::CompileOptions& compile_options, bool compile_single_op) { CHECK_NE(executable, nullptr); - VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); + VLOG(2) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { VLOG(2) << "num_inputs=" << ctx->num_inputs() @@ -310,7 +310,7 @@ Status XlaCompilationCache::CompileImpl( // cache eviction. mutex_lock entry_lock(entry->mu); if (!entry->compiled) { - VLOG(1) << "Compilation cache miss for signature: " + VLOG(2) << "Compilation cache miss for signature: " << SignatureDebugString(signature); tensorflow::Env* env = tensorflow::Env::Default(); const uint64 compile_start_us = env->NowMicros(); -- GitLab From 1486421be066d740ccf55426c013e4d32e78ad91 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 12:52:22 -0700 Subject: [PATCH 133/540] Make TFLite NNAPI delegate friendlier to application code. Esp. allows running benchmark on O-MR1 without an exit() of the process. Also fixes bug in interpretation of error values (NNAPI vs. TFLite error codes). PiperOrigin-RevId: 211681942 --- tensorflow/contrib/lite/nnapi_delegate.cc | 65 +++++++++++++++-------- 1 file changed, 42 insertions(+), 23 deletions(-) diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 602f3ee5d2..484842713d 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -64,6 +64,14 @@ void logError(const char* format, ...) { __LINE__); \ } +#define RETURN_ERROR_IF_TFLITE_FAILED(x) \ + if (x != kTfLiteOk) { \ + logError( \ + "Returning error since TFLite returned failure nnapi_delegate.cc:%d.", \ + __LINE__); \ + return kTfLiteError; \ + } + #define RETURN_ERROR_IF_NN_FAILED(x) \ if (x != ANEURALNETWORKS_NO_ERROR) { \ logError( \ @@ -299,17 +307,21 @@ TfLiteStatus AddOpsAndParams( }; auto check_and_add_activation = [&add_scalar_int32](int activation) { if (activation > kTfLiteActRelu6) { - FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + logError("NNAPI only supports RELU, RELU1 and RELU6 activations"); + return kTfLiteError; } add_scalar_int32(activation); + return kTfLiteOk; }; auto add_add_params = [&add_scalar_int32](void* data) { auto* builtin = reinterpret_cast(data); if (builtin->activation > kTfLiteActRelu6) { - FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations"); + logError("NNAPI only supports RELU, RELU1 and RELU6 activations"); + return kTfLiteError; } add_scalar_int32(builtin->activation); + return kTfLiteOk; }; auto add_pooling_params = [&add_scalar_int32, @@ -320,7 +332,7 @@ TfLiteStatus AddOpsAndParams( add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->filter_width); add_scalar_int32(builtin->filter_height); - check_and_add_activation(builtin->activation); + return check_and_add_activation(builtin->activation); }; auto add_convolution_params = [&add_scalar_int32, @@ -329,7 +341,7 @@ TfLiteStatus AddOpsAndParams( add_scalar_int32(builtin->padding); add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); - check_and_add_activation(builtin->activation); + return check_and_add_activation(builtin->activation); }; auto add_depthwise_conv_params = [&add_scalar_int32, @@ -339,20 +351,22 @@ TfLiteStatus AddOpsAndParams( add_scalar_int32(builtin->stride_width); add_scalar_int32(builtin->stride_height); add_scalar_int32(builtin->depth_multiplier); - check_and_add_activation(builtin->activation); + return check_and_add_activation(builtin->activation); }; auto add_fully_connected_params = [&check_and_add_activation](void* data) { auto builtin = reinterpret_cast(data); - check_and_add_activation(builtin->activation); + return check_and_add_activation(builtin->activation); }; auto add_concatenation_params = [&add_scalar_int32](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->axis); if (builtin->activation != kTfLiteActNone) { - FATAL("Concatenation does not support fused activation in NNAPI"); + logError("Concatenation does not support fused activation in NNAPI"); + return kTfLiteError; } + return kTfLiteOk; }; auto add_softmax_params = [&add_scalar_float32](void* data) { @@ -433,22 +447,22 @@ TfLiteStatus AddOpsAndParams( switch (builtin) { case tflite::BuiltinOperator_ADD: nn_op_type = ANEURALNETWORKS_ADD; - add_add_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data)); break; case tflite::BuiltinOperator_MUL: nn_op_type = ANEURALNETWORKS_MUL; - add_add_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data)); break; case tflite::BuiltinOperator_AVERAGE_POOL_2D: - add_pooling_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data)); nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D; break; case tflite::BuiltinOperator_MAX_POOL_2D: - add_pooling_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data)); nn_op_type = ANEURALNETWORKS_MAX_POOL_2D; break; case tflite::BuiltinOperator_L2_POOL_2D: - add_pooling_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data)); nn_op_type = ANEURALNETWORKS_L2_POOL_2D; break; case tflite::BuiltinOperator_CONV_2D: { @@ -459,7 +473,8 @@ TfLiteStatus AddOpsAndParams( return kTfLiteError; } } - add_convolution_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED( + add_convolution_params(node.builtin_data)); nn_op_type = ANEURALNETWORKS_CONV_2D; break; case tflite::BuiltinOperator_RELU: @@ -478,11 +493,13 @@ TfLiteStatus AddOpsAndParams( nn_op_type = ANEURALNETWORKS_LOGISTIC; break; case tflite::BuiltinOperator_DEPTHWISE_CONV_2D: - add_depthwise_conv_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED( + add_depthwise_conv_params(node.builtin_data)); nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D; break; case tflite::BuiltinOperator_CONCATENATION: - add_concatenation_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED( + add_concatenation_params(node.builtin_data)); nn_op_type = ANEURALNETWORKS_CONCATENATION; break; case tflite::BuiltinOperator_SOFTMAX: @@ -490,7 +507,8 @@ TfLiteStatus AddOpsAndParams( nn_op_type = ANEURALNETWORKS_SOFTMAX; break; case tflite::BuiltinOperator_FULLY_CONNECTED: - add_fully_connected_params(node.builtin_data); + RETURN_ERROR_IF_TFLITE_FAILED( + add_fully_connected_params(node.builtin_data)); nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED; break; case tflite::BuiltinOperator_RESHAPE: @@ -544,14 +562,14 @@ TfLiteStatus AddOpsAndParams( case tflite::BuiltinOperator_DIV: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_DIV; - check_and_add_activation( - reinterpret_cast(node.builtin_data)->activation); + RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation)); break; case tflite::BuiltinOperator_SUB: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_SUB; - check_and_add_activation( - reinterpret_cast(node.builtin_data)->activation); + RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation( + reinterpret_cast(node.builtin_data)->activation)); break; case tflite::BuiltinOperator_SQUEEZE: nnapi_version = 11; // requires NNAPI 1.1 @@ -664,7 +682,8 @@ TfLiteStatus AddOpsAndParams( } if (nnapi_version == 11 && GetAndroidSdkVersionCached() < 28) { - FATAL("Op %d needs NNAPI1.1", builtin); + logError("Op %d needs NNAPI1.1", builtin); + return kTfLiteError; } // Add the operation. @@ -712,9 +731,9 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { interpreter->outputs().size()); uint32_t next_id = 0; - RETURN_ERROR_IF_NN_FAILED(addTensorOperands( + RETURN_ERROR_IF_TFLITE_FAILED(addTensorOperands( interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id)); - RETURN_ERROR_IF_NN_FAILED( + RETURN_ERROR_IF_TFLITE_FAILED( AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, &model_states_outputs_, tensor_id_to_nnapi_id)); -- GitLab From 5d3f444034e6b9af914a59efe9f8de2710079e13 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Wed, 5 Sep 2018 12:52:29 -0700 Subject: [PATCH 134/540] BEGIN_PUBLIC Automated rollback of commit 7fa693209fe238478739b3982f652a7e35be91f3 PiperOrigin-RevId: 211681957 --- tensorflow/compiler/xla/service/BUILD | 48 --- .../compiler/xla/service/buffer_assignment.cc | 28 +- .../xla/service/buffer_assignment_test.cc | 98 +++-- .../xla/service/buffer_liveness_test.cc | 42 +-- .../compiler/xla/service/cpu/cpu_compiler.cc | 56 +-- .../compiler/xla/service/cpu/ir_emitter.cc | 2 +- .../compiler/xla/service/cpu/ir_emitter.h | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 1 - .../xla/service/gpu/gpu_hlo_schedule.cc | 6 +- .../xla/service/gpu/gpu_hlo_schedule.h | 4 +- .../compiler/xla/service/heap_simulator.cc | 43 ++- .../compiler/xla/service/heap_simulator.h | 48 +-- .../xla/service/heap_simulator_test.cc | 36 +- .../xla/service/hlo_alias_analysis_test.cc | 16 +- .../xla/service/hlo_dataflow_analysis_test.cc | 29 +- .../compiler/xla/service/hlo_ordering.cc | 86 +++-- .../compiler/xla/service/hlo_ordering.h | 22 +- .../compiler/xla/service/hlo_ordering_test.cc | 101 ------ .../xla/service/hlo_rematerialization.cc | 87 +++-- .../xla/service/hlo_rematerialization.h | 19 +- .../xla/service/hlo_rematerialization_test.cc | 46 ++- .../compiler/xla/service/hlo_schedule.cc | 291 --------------- .../compiler/xla/service/hlo_schedule.h | 151 -------- .../compiler/xla/service/hlo_schedule_test.cc | 341 ----------------- .../compiler/xla/service/hlo_scheduling.cc | 230 ++++++++++-- .../compiler/xla/service/hlo_scheduling.h | 54 ++- .../xla/service/hlo_scheduling_test.cc | 343 +++++++++++++++--- 27 files changed, 905 insertions(+), 1325 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/hlo_schedule.cc delete mode 100644 tensorflow/compiler/xla/service/hlo_schedule.h delete mode 100644 tensorflow/compiler/xla/service/hlo_schedule_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 612302781c..f6cfac6537 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -989,7 +989,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1037,7 +1036,6 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1051,7 +1049,6 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1065,7 +1062,6 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", - ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1086,7 +1082,6 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1094,7 +1089,6 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", ], ) @@ -1108,7 +1102,6 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", - ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1132,7 +1125,6 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", - "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1177,43 +1169,6 @@ cc_library( ], ) -cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib_internal", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ], -) - -tf_cc_test( - name = "hlo_schedule_test", - srcs = ["hlo_schedule_test.cc"], - deps = [ - ":heap_simulator", - ":hlo", - ":hlo_dce", - ":hlo_ordering", - ":hlo_parser", - ":hlo_schedule", - ":hlo_scheduling", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - "@com_google_absl//absl/algorithm:container", - ], -) - cc_library( name = "hlo_scheduling", srcs = ["hlo_scheduling.cc"], @@ -1222,7 +1177,6 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1251,7 +1205,6 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", - "@com_google_absl//absl/algorithm:container", ], ) @@ -2413,7 +2366,6 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 0f0af57626..8b8c6bfd26 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -617,24 +617,18 @@ Status BufferAssignment::ComputeSummaryStats() { } // Only compute total fragmentation if all computations have schedules. - HloSchedule schedule(module_); - bool schedule_complete = true; + SequentialHloOrdering::HloModuleSequence module_sequence; for (const auto& computation : module_->computations()) { - if (!computation->IsFusionComputation()) { - const std::vector* sequence = - liveness_->hlo_ordering().SequentialOrder(*computation); - if (sequence == nullptr) { - schedule_complete = false; - } else { - schedule.set_sequence(computation, *sequence); - } + const std::vector* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence != nullptr) { + module_sequence.emplace(computation, *sequence); } } - if (schedule_complete) { - TF_RETURN_IF_ERROR(schedule.Verify()); + if (module_sequence.size() == module_->computation_count()) { TF_ASSIGN_OR_RETURN( const int64 min_size, - HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -1070,7 +1064,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; - HloSchedule schedule(&assignment->module()); + SequentialHloOrdering::HloModuleSequence module_sequence; FlatSet all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; @@ -1078,7 +1072,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); - schedule.set_sequence(computation, *instruction_sequence); + module_sequence[computation] = *instruction_sequence; all_buffers_to_assign.insert(buffers_to_assign.begin(), buffers_to_assign.end()); } @@ -1096,7 +1090,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique( absl::make_unique(alignment)), - assignment->module(), schedule, + assignment->module(), module_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1127,7 +1121,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run( absl::make_unique( absl::make_unique(alignment)), - *computation, HloInstructionSequence(*instruction_sequence), + *computation, *instruction_sequence, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 03e155fc11..7398f105a0 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -33,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -41,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -122,10 +120,14 @@ class BufferAssignmentTest : public HloVerifiedTestBase { HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { - HloSchedule schedule(module); - schedule.set_sequence(module->entry_computation(), instruction_sequence); + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[module->entry_computation()] = + std::vector(instruction_sequence.begin(), + instruction_sequence.end()); return BufferAssigner::Run( - module, absl::make_unique(schedule), + module, + absl::make_unique(module, + module_sequence), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1783,10 +1785,11 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { - HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + auto sequence = + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, absl::make_unique(schedule), + module, + absl::make_unique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -2093,25 +2096,17 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // Create a sequential order among all the instructions in the entry // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - schedule.set_sequence( - module->entry_computation(), - {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}); - TF_ASSERT_OK(schedule.Verify()); - + SequentialHloOrdering::HloModuleSequence sequence; + sequence[module->entry_computation()] = { + token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run(module, - absl::make_unique(schedule), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run( + module, absl::make_unique(module, sequence), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2268,6 +2263,29 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } +static bool IsPostOrderTraversal( + const std::vector& sequence) { + tensorflow::gtl::FlatSet seen_so_far; + auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { + return seen_so_far.count(instruction) == 0; + }; + + for (auto instruction : sequence) { + if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), has_not_been_seen_yet) || + std::any_of(instruction->control_predecessors().begin(), + instruction->control_predecessors().end(), + has_not_been_seen_yet)) { + return false; // Not a post order. + } + if (!seen_so_far.insert(instruction).second) { + return false; // Not a "traversal". + } + } + + return true; +} + TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); @@ -2322,27 +2340,27 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module); - HloSchedule schedule = - ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); + auto sequence = + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); - // To trigger b/38494731, we want a specific Hlo schedule for the + // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually // crafted sequence. - schedule.set_sequence(module->entry_computation(), - {input1, weights1, one, output1, while1->operand(0), - while1, input0, weights0, zero, output0, - while0->operand(0), while0, gte0, gte1, root_add}); + sequence[module->entry_computation()] = { + input1, weights1, one, output1, while1->operand(0), while1, + input0, weights0, zero, output0, while0->operand(0), while0, + gte0, gte1, root_add}; - // If this ASSERT fails, we constructed a bogus sequence above and this test - // itself is buggy. - TF_ASSERT_OK(schedule.Verify()); + // If this ASSERT_TRUE fails, we constructed a bogus sequence above + // and this test itself is buggy. + ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); auto assignment = - BufferAssigner::Run(module, - absl::make_unique(schedule), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run( + module, absl::make_unique(module, sequence), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 414bfe7999..26e26e316d 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -167,12 +166,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto module = CreateNewModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); - HloSchedule schedule(module.get()); - schedule.set_sequence(entry, {param0, negate, param1, exp, add}); - auto liveness = - BufferLiveness::Run(module.get(), - absl::make_unique(schedule)) - .ConsumeValueOrDie(); + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param0, negate, param1, exp, add}}); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -292,12 +291,13 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - HloSchedule schedule(module.get()); - schedule.set_sequence(computation, {param, negate, exp, add}); - auto liveness = - BufferLiveness::Run(module.get(), - absl::make_unique(schedule)) - .ConsumeValueOrDie(); + SequentialHloOrdering::HloModuleSequence module_sequence; + std::vector order = {param, negate, exp, add}; + module_sequence.emplace(computation, order); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build(add)); - HloSchedule schedule(module.get()); - schedule.set_sequence(computation, - {param, add, token, recv, recv_done, send, send_done}); - TF_ASSERT_OK(schedule.Verify()); - auto liveness = - BufferLiveness::Run(module.get(), - absl::make_unique(schedule)) - .ConsumeValueOrDie(); + SequentialHloOrdering::HloModuleSequence module_sequence; + std::vector order = {param, add, recv, + recv_done, send, send_done}; + module_sequence.emplace(computation, order); + auto liveness = BufferLiveness::Run(module.get(), + absl::make_unique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e7b6075994..796f36510e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -584,14 +584,16 @@ StatusOr> CpuCompiler::RunBackend( // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( - HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); + SequentialHloOrdering::HloModuleSequence module_sequence, + ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run(module.get(), - absl::make_unique(schedule), + absl::make_unique( + module.get(), module_sequence), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -625,10 +627,9 @@ StatusOr> CpuCompiler::RunBackend( } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation( - embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + .EmitComputation(embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &module_sequence.at(embedded_computation)) .status()); } string function_name_prefix = entry_computation->name().empty() @@ -636,10 +637,9 @@ StatusOr> CpuCompiler::RunBackend( : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation( - entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &schedule.sequence(entry_computation).instructions())); + ir_emitter.EmitComputation(entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &module_sequence.at(entry_computation))); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -771,18 +771,20 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN(HloSchedule schedule, - ScheduleModule(*module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN( + SequentialHloOrdering::HloModuleSequence module_sequence, + ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(module, - absl::make_unique(schedule), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run( + module, + absl::make_unique(module, module_sequence), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -822,18 +824,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation( - embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &schedule.sequence(embedded_computation).instructions()) + .EmitComputation(embedded_computation, + embedded_computation->name(), + /*is_top_level_computation=*/false, + &module_sequence.at(embedded_computation)) .status()); } const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, - ir_emitter.EmitComputation( - computation, entry_point_name, - /*is_top_level_computation=*/true, - &schedule.sequence(computation).instructions())); + TF_ASSIGN_OR_RETURN( + llvm::Function * entry_function, + ir_emitter.EmitComputation(computation, entry_point_name, + /*is_top_level_computation=*/true, + &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index df8c2a636b..e5cf15c686 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -110,7 +110,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order) { + std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 3df99464ba..58a333b8fb 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -98,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - const std::vector* instruction_order); + std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13ccff35f8..a68b7a1bef 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -813,7 +813,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index ea9376e101..743035a84e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -22,7 +22,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -199,12 +198,11 @@ StatusOr> GpuHloSchedule::Build( // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( - HloInstructionSequence sequence, - ScheduleComputation( + schedule->thunk_launch_order_, + ScheduleOneComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); - schedule->thunk_launch_order_ = sequence.instructions(); } else { // BFS tends to increase concurrency, but also increases memory usage. BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 07a7fc67aa..30a0e7cecd 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -33,9 +33,7 @@ namespace gpu { // launches, because thunks may be scheduled onto concurrent streams. This // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk -// launch order. This class differs from xla::HloSchedule in that HloSchedule -// represents a total order of all instructions in the module for backends which -// execute HLO instructions strictly sequentially. +// launch order. class GpuHloSchedule { public: // Constructs an GpuHloSchedule for the given module, based on the given diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index e0f3a7e0e2..38c3982ebf 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( - const HloSchedule& schedule, + const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function) { - if (schedule.empty()) { + if (module_sequence.empty()) { return 0; } - const HloModule* module = schedule.module(); + const HloModule* module = module_sequence.begin()->first->parent(); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -47,13 +47,14 @@ StatusOr HeapSimulator::MinimumMemoryForModule( TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique(), *module, - schedule, *points_to_analysis, size_function)); + module_sequence, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, const HloInstructionSequence& sequence, + const HloComputation& computation, + const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* @@ -70,13 +71,13 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, - const HloSchedule& schedule, + const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); + HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); - const HloInstructionSequence& instruction_sequence = - schedule.sequence(entry_computation); + const std::vector& instruction_sequence = + FindOrDie(module_sequence, entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -85,13 +86,13 @@ StatusOr HeapSimulator::Run( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, - const HloInstructionSequence& instruction_sequence, + const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, const tensorflow::gtl::FlatMap* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*schedule=*/nullptr, memory_by_computation); + /*module_sequence=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -101,7 +102,7 @@ StatusOr HeapSimulator::Run( // 'instruction_sequence'. Status HeapSimulator::RunComputation( const HloComputation& computation, - const HloInstructionSequence& instruction_sequence, + const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -132,8 +133,7 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : - instruction_sequence.instructions()) { + for (const HloInstruction* instruction : instruction_sequence) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -166,8 +166,7 @@ Status HeapSimulator::RunComputation( std::vector dead_buffers_to_free; std::vector operand_buffers_to_free; - for (const HloInstruction* instruction : - instruction_sequence.instructions()) { + for (const HloInstruction* instruction : instruction_sequence) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -286,14 +285,14 @@ Status HeapSimulator::RunComputation( // The order that the sub-computations are simulated does not affect // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. - if (schedule_ != nullptr) { + if (module_sequence_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const HloInstructionSequence& called_sequence = - schedule_->sequence(called_computation); + const std::vector& called_sequence = + FindOrDie(*module_sequence_, called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -344,16 +343,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const HloSchedule* schedule, + const SequentialHloOrdering::HloModuleSequence* module_sequence, const tensorflow::gtl::FlatMap* memory_by_computation) : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - schedule_(schedule), + module_sequence_(module_sequence), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(schedule_ != nullptr); + debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } HeapSimulator::~HeapSimulator() {} diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index ffbf947d5a..af05bedee7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -89,22 +88,23 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given - // schedule), assuming no fragmentation. + // module_sequence), assuming no fragmentation. static StatusOr MinimumMemoryForModule( - const HloSchedule& schedule, + const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. static StatusOr MinimumMemoryForComputation( - const HloComputation& computation, const HloInstructionSequence& sequence, + const HloComputation& computation, + const std::vector& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given - // schedule, which must contain a topologically-consistent total + // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid // if instructions are not run in exactly this sequence. // @@ -112,12 +112,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr Run(std::unique_ptr algorithm, - const HloModule& module, - const HloSchedule& schedule, - const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr Run( + std::unique_ptr algorithm, const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const TuplePointsToAnalysis& points_to_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -126,7 +126,7 @@ class HeapSimulator { static StatusOr Run( std::unique_ptr algorithm, const HloComputation& computation, - const HloInstructionSequence& instruction_sequence, + const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), @@ -134,19 +134,21 @@ class HeapSimulator { memory_by_computation = nullptr); private: - // If 'schedule' is non-null, it is used to find kCall and kWhile + // If 'module_sequence' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. - HeapSimulator(std::unique_ptr algorithm, - const BufferValue::SizeFunction& size_fn, - const Options& options, const HloSchedule* schedule = nullptr, - const tensorflow::gtl::FlatMap* - memory_by_computation = nullptr); + HeapSimulator( + std::unique_ptr algorithm, + const BufferValue::SizeFunction& size_fn, const Options& options, + const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); - Status RunComputation(const HloComputation& computation, - const HloInstructionSequence& instruction_sequence, - const TuplePointsToAnalysis& points_to_analysis); + Status RunComputation( + const HloComputation& computation, + const std::vector& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis); bool IgnoreBuffer(const BufferValue* buffer) const; void Alloc(const BufferValue* buffer, const HloInstruction* instruction); @@ -167,11 +169,11 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // schedule_ is set by buffer assignment, and memory_by_computation_ is + // module_sequence_ is set by buffer assignment, and memory_by_computation_ is // set by hlo scheduling. Then, in RunComputation, we check both in order to // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. - const HloSchedule* schedule_; + const SequentialHloOrdering::HloModuleSequence* module_sequence_; const tensorflow::gtl::FlatMap* memory_by_computation_; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 00a25db467..7ad8a107e1 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -86,16 +85,13 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - HloSchedule schedule(module.get()); - schedule.set_sequence(cond_computation, - {cond_param, cond_iter, cond_data, cond_lt}); - schedule.set_sequence(body_computation, {body_param}); - schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); - TF_ASSERT_OK(schedule.Verify()); - - EXPECT_EQ( - 56, - HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) + .ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -153,11 +149,10 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = - HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), - HloInstructionSequence(instruction_sequence), - *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + result_ = HeapSimulator::Run( + std::move(algorithm), *module_->entry_computation(), + instruction_sequence, *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { @@ -173,12 +168,11 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. - HloSchedule schedule(module_.get()); + SequentialHloOrdering::HloModuleSequence module_sequence; tensorflow::gtl::FlatMap reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - schedule.GetOrCreateSequence(instruction->parent()) - .push_back(instruction); + module_sequence[instruction->parent()].push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -191,8 +185,8 @@ class HeapSimulatorTracker { }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, - *points_to_analysis_, size_fn) + result_ = HeapSimulator::Run(std::move(algorithm), *module_, + module_sequence, *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 0cd0ab36fc..54abe3345d 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -885,20 +885,18 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // For a sequential order, if there is interference iff the negate is after // the while. - HloSchedule schedule(module_); - schedule.set_sequence(body, {body_param, body_root}); - schedule.set_sequence(condition, {cond_param, cond_root}); + SequentialHloOrdering::HloModuleSequence sequence; + sequence[body] = {body_param, body_root}; + sequence[condition] = {cond_param, cond_root}; { - schedule.set_sequence(entry, {init, xla_while, negate, entry_root}); - TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(schedule); + sequence[entry] = {init, xla_while, negate, entry_root}; + SequentialHloOrdering ordering(module_, sequence); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { - schedule.set_sequence(entry, {init, negate, xla_while, entry_root}); - TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(schedule); + sequence[entry] = {init, negate, xla_while, entry_root}; + SequentialHloOrdering ordering(module_, sequence); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 0a86f83ed9..62eea2b06c 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -1262,10 +1261,9 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - HloSchedule schedule(module_.get()); - schedule.set_sequence(entry, {param0, negate, param1, exp, add}); - TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(schedule); + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param0, negate, param1, exp, add}}); + SequentialHloOrdering ordering(module_.get(), sequence); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -1341,16 +1339,14 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { bool ssa_form = GetParam(); RunAnalysis(ssa_form); - HloSchedule schedule(module_.get()); - schedule.set_sequence(entry, {param, xla_while}); - schedule.set_sequence(condition, {cond_param, cond_constant}); + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param, xla_while}}); + sequence.insert({condition, {cond_param, cond_constant}}); // Construct the order such that 'constant' and its use 'exp' are before // body_param. - schedule.set_sequence( - body, {constant, exp, body_param, add, dead_constant, dead_negate}); - TF_ASSERT_OK(schedule.Verify()); + sequence.insert({body, {constant, exp, body_param, add}}); - SequentialHloOrdering ordering(schedule); + SequentialHloOrdering ordering(module_.get(), sequence); // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. @@ -1480,10 +1476,11 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - HloSchedule schedule(module_.get()); - schedule.set_sequence(entry, {param, negate, exp, add}); - TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(schedule); + SequentialHloOrdering::HloModuleSequence sequence; + std::vector order = {param, negate, exp, add}; + sequence.emplace(entry, order); + + SequentialHloOrdering ordering(module_.get(), sequence); EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 2105f7a349..0581d5c404 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -253,12 +252,6 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } - - if (a.live_out_of_module()) { - VLOG(4) << a << " is live out of module and defined before " << b; - return false; - } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), @@ -271,18 +264,6 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } } - - if (a.instruction()->parent() == b.instruction()->parent()) { - for (const HloPosition& position : a.positions()) { - if (position.instruction == - a.instruction()->parent()->root_instruction()) { - VLOG(4) << a << " is live out of computation and defined before " << b - << " which is in same computation"; - return false; - } - } - } - return true; } @@ -355,24 +336,15 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) - : HloOrdering(schedule.module()), schedule_(schedule) { - Initialize(); -} - -SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) - : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { - Initialize(); -} - -void SequentialHloOrdering::Initialize() { +SequentialHloOrdering::SequentialHloOrdering( + const HloModule* module, const HloModuleSequence& module_sequence) + : HloOrdering(module), module_sequence_(module_sequence) { // Create a map from instruction to its order position. - TF_DCHECK_OK(schedule_.Verify()); - for (const auto& computation_sequence : schedule_.sequences()) { - const std::vector& order = - computation_sequence.second.instructions(); + for (auto computation_order : module_sequence_) { + const std::vector& order = computation_order.second; for (int i = 0; i < order.size(); ++i) { - InsertOrDie(&order_position_, order[i], i); + DCHECK_EQ(0, order_position_.count(order[i])); + order_position_.emplace(order[i], i); } } } @@ -390,13 +362,49 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - return schedule_.is_computation_scheduled(&computation) - ? &schedule_.sequence(&computation).instructions() - : nullptr; + auto find_it = module_sequence_.find(&computation); + return find_it == module_sequence_.end() ? nullptr : &find_it->second; } string SequentialHloOrdering::ToString() const { - return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); + std::vector pieces; + pieces.push_back("SequentialHloOrdering"); + for (auto* computation : module_->computations()) { + pieces.push_back( + absl::StrFormat("computation %s order:", computation->name())); + // Gather all instructions in the module sequence for this computation and + // sort them by their position. + std::vector instructions; + for (auto& instruction_position : order_position_) { + const HloInstruction* instruction = instruction_position.first; + if (instruction->parent() == computation) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [this](const HloInstruction* a, const HloInstruction* b) { + return order_position_.at(a) < order_position_.at(b); + }); + for (auto instruction : instructions) { + pieces.push_back(absl::StrFormat(" %s", instruction->name())); + } + } + return absl::StrJoin(pieces, "\n"); +} + +std::ostream& operator<<( + std::ostream& out, + const SequentialHloOrdering::HloModuleSequence& module_sequence) { + for (auto computation_pair : module_sequence) { + const HloComputation* computation = computation_pair.first; + const std::vector& computation_sequence = + computation_pair.second; + out << "Computation " << computation->name() << ":\n"; + for (auto* instruction : computation_sequence) { + out << " " << instruction->name() << "\n"; + } + } + return out; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b21071c4b2..985f3fa64d 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -184,8 +183,17 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - SequentialHloOrdering(const HloSchedule& schedule); - SequentialHloOrdering(HloSchedule&& schedule); + // TODO(dimvar): HloModuleSequence is not a good name because it sounds like + // a sequence of modules, instead of a map of schedules for all computations + // in a module. We should change it at some point. + // + // A sequence of instructions for each computation in the module. + using HloModuleSequence = + tensorflow::gtl::FlatMap>; + + SequentialHloOrdering(const HloModule* module, + const HloModuleSequence& module_sequence); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. @@ -195,12 +203,10 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: - void Initialize(); - bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloSchedule schedule_; + const HloModuleSequence module_sequence_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -211,6 +217,10 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap order_position_; }; +std::ostream& operator<<( + std::ostream& out, + const SequentialHloOrdering::HloModuleSequence& module_sequence); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 6b6005e7a5..126d3a2d9c 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,13 +23,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -378,104 +376,5 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } -TEST_F(HloOrderingTest, - ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { - // Tests that values live out of the module should interfere with values - // defined after the root instruction. That is: - // - // %param = param(0) - // ROOT %root = negate(%param) - // %dead = Constant(123.0) - // - // %root should interfere with %dead. - auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - - auto builder = HloComputation::Builder(TestName()); - HloInstruction* param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param")); - HloInstruction* root = builder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - HloInstruction* dead = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); - HloComputation* entry = - module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); - - HloSchedule schedule(module.get()); - schedule.set_sequence(entry, {param, root, dead}); - TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(schedule); - - TF_ASSERT_OK_AND_ASSIGN(auto dataflow, - HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); - - EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); - EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); - - EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( - dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), - *dataflow)); - - EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), - dataflow->GetValueDefinedAt(dead), - *dataflow)); -} - -TEST_F(HloOrderingTest, - ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { - // Tests that values live out of a computation should interfere with values - // defined after the root instruction of the computation. That is: - // - // subcomputation: - // %param = param(0) - // ROOT %root = negate(%param) - // %dead = Constant(123.0) - // - // entry computation: - // %c = constant(42.0) - // ROOT %call = call({%c}), subcomputation - // - // %root should interfere with %dead. - auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - - auto subbuilder = HloComputation::Builder(TestName() + ".sub"); - HloInstruction* param = subbuilder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param")); - HloInstruction* root = subbuilder.AddInstruction( - HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); - HloInstruction* dead = subbuilder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); - HloComputation* subcomputation = module->AddEmbeddedComputation( - subbuilder.Build(/*root_instruction=*/root)); - - auto builder = HloComputation::Builder(TestName()); - HloInstruction* c = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); - HloInstruction* call = builder.AddInstruction( - HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); - HloComputation* entry = module->AddEntryComputation(builder.Build()); - - HloSchedule schedule(module.get()); - schedule.set_sequence(subcomputation, {param, root, dead}); - schedule.set_sequence(entry, {c, call}); - TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(schedule); - - TF_ASSERT_OK_AND_ASSIGN(auto dataflow, - HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); - - EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); - EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); - - EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( - dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), - *dataflow)); - - EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), - dataflow->GetValueDefinedAt(dead), - *dataflow)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 0a0a6a323e..c9629926ea 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -962,7 +962,8 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( } StatusOr HloRematerialization::RematerializeComputation( - HloComputation* computation, HloSchedule* schedule, + HloComputation* computation, + SequentialHloOrdering::HloModuleSequence* sequence, int64 memory_limit_bytes) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); @@ -970,8 +971,7 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list( - schedule->sequence(computation).instructions()); + InstructionList instruction_list(sequence->at(computation)); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1145,7 +1145,7 @@ StatusOr HloRematerialization::RematerializeComputation( 0, memory_limit_bytes - memory_tracker.memory_usage()); TF_ASSIGN_OR_RETURN( bool subcomputation_changed, - RematerializeComputation(called_computation, schedule, + RematerializeComputation(called_computation, sequence, subcomputation_memory_limit_bytes)); changed |= subcomputation_changed; } @@ -1179,12 +1179,12 @@ StatusOr HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. - HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation); - sequence.clear(); + auto& dst = sequence->at(computation); + dst.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - sequence.push_back(instruction); + dst.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1194,21 +1194,20 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run(HloModule* module, - HloSchedule* schedule, - int64 memory_limit_bytes, - RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The schedule is constructed entirely by this method. - TF_RET_CHECK(schedule->empty()); +StatusOr HloRematerialization::Run( + HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, + int64 memory_limit_bytes, RematerializationSizes* sizes, + CopyInsertion* copy_insertion) { + // The sequence is constructed entirely by this method. + TF_RET_CHECK(sequence->empty()); VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial schedule of HLO instructions. - TF_ASSIGN_OR_RETURN(*schedule, - ScheduleModule(*module, + // Create initial sequence of HLO instructions. + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( + *module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, @@ -1218,7 +1217,16 @@ StatusOr HloRematerialization::Run(HloModule* module, // ordering from the HLO schedule allows for more copies to be eliminated. // TODO(b/80249101): Instead of a separate copy elision pass, use the // ordering from the HLO schedule directly for copy insertion. - SequentialHloOrdering ordering(*schedule); + + // First create a copy of the schedule which contains HloInstruction unique + // ids instead of HloInstruction*. This is necessary for updating the + // schedule below. + // TODO(b/113175018): Remove this when the HLO schedule is self-contained + // and can update itself. + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(*sequence); + + SequentialHloOrdering ordering(module, *sequence); TF_RETURN_IF_ERROR( copy_insertion->RemoveUnnecessaryCopies(ordering, module)); @@ -1233,10 +1241,10 @@ StatusOr HloRematerialization::Run(HloModule* module, // The passes above can add and remove copies, update the schedule to // account for these transformations. Newly added instructions will be // placed ASAP in the schedule. - TF_RETURN_IF_ERROR(schedule->Update()); + TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(*schedule), module)); + SequentialHloOrdering(module, *sequence), module)); } TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); @@ -1263,13 +1271,12 @@ StatusOr HloRematerialization::Run(HloModule* module, // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, schedule](const CallGraphNode& node) -> Status { + [this, sequence](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory( - node.computation(), - schedule->sequence(node.computation()).instructions())); + ComputePeakMemory(node.computation(), + sequence->at(node.computation()))); } return Status::OK(); }, @@ -1288,7 +1295,7 @@ StatusOr HloRematerialization::Run(HloModule* module, // Subcomputations called by the entry computation will also be // rematerialized. TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), schedule, + module->entry_computation(), sequence, adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an @@ -1298,7 +1305,30 @@ StatusOr HloRematerialization::Run(HloModule* module, // After DCE, the module sequence may include instructions which no longer // exist. - TF_RETURN_IF_ERROR(schedule->Update()); + for (const auto* computation : module->MakeNonfusionComputations()) { + if (sequence->at(computation).size() != computation->instruction_count()) { + // A size mismatch between the computation instruction count and the size + // of the ordering of instructions can only be caused by DCE. Rebuild the + // order by removing the deleted instructions from the order. + tensorflow::gtl::FlatSet instruction_set; + for (const auto& instruction : computation->instructions()) { + instruction_set.insert(instruction); + } + // Move the old order into a temporary vector, then build new order + // inplace. + std::vector& order = sequence->at(computation); + std::vector old_order; + using std::swap; + swap(order, old_order); + std::copy_if(old_order.begin(), old_order.end(), + std::back_inserter(order), + [&instruction_set](const HloInstruction* instruction) { + return ContainsKey(instruction_set, instruction); + }); + TF_RET_CHECK(sequence->at(computation).size() == + computation->instruction_count()); + } + } VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1336,10 +1366,11 @@ StatusOr HloRematerialization::Run(HloModule* module, /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule, + MemorySchedulerAlgorithm scheduler_algorithm, + SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes, CopyInsertion* copy_insertion) { HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes, + return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, copy_insertion); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index fa0414b472..2ec004350a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -21,7 +21,6 @@ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -51,7 +50,7 @@ class HloRematerialization { // // hlo_module: HLO module to rematerialize instructions in. // - // schedule: Should point to an empty HloSchedule. Upon return + // sequence: Should point to an empty HloModuleSequence. Upon return // contains the HLO instruction order which was used for // rematerialization. This is the order in which HLO instructions should // be emitted to minimize memory use. @@ -76,8 +75,8 @@ class HloRematerialization { static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - HloSchedule* schedule, RematerializationSizes* sizes, - CopyInsertion* copy_insertion = nullptr); + SequentialHloOrdering::HloModuleSequence* sequence, + RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -88,9 +87,10 @@ class HloRematerialization { // Runs rematerialization on the given module. Returns whether the module was // changed. memory_limit is the target maximum peak memory usage by the - // module. schedule should be an empty HloSchedule. Upon return sequence + // module. sequence should be an empty HloModuleSequence. Upon return sequence // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr Run(HloModule* module, HloSchedule* schedule, + StatusOr Run(HloModule* module, + SequentialHloOrdering::HloModuleSequence* sequence, int64 memory_limit, RematerializationSizes* sizes, CopyInsertion* copy_insertion); @@ -98,9 +98,10 @@ class HloRematerialization { // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. - StatusOr RematerializeComputation(HloComputation* computation, - HloSchedule* schedule, - int64 memory_limit_bytes); + StatusOr RematerializeComputation( + HloComputation* computation, + SequentialHloOrdering::HloModuleSequence* sequence, + int64 computation_memory_limit); // Computes and returns the peak memory used by the given computation. The // peak memory is the maximum total size of all live HLO instruction values at diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83cb113bfb..ac8c97d380 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -141,13 +141,13 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - StatusOr RunHloRematerialization(int64 memory_limit_bytes, - HloModule* module, - HloSchedule* schedule) { + StatusOr RunHloRematerialization( + int64 memory_limit_bytes, HloModule* module, + SequentialHloOrdering::HloModuleSequence* sequence) { TF_EXPECT_OK(verifier().Run(module).status()); return HloRematerialization::RematerializeAndSchedule( ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - schedule, /*sizes=*/nullptr); + sequence, /*sizes=*/nullptr); } // Various shapes used in the canned computations. @@ -170,12 +170,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/14 * 1024, - module.get(), &schedule)); + module.get(), &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,11 +187,9 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(schedule.sequence(computation) - .instructions()[computation->instruction_count() - 2], + EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], concat); - EXPECT_EQ(schedule.sequence(computation) - .instructions()[computation->instruction_count() - 3], + EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], remat_bcast); } @@ -205,10 +203,10 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/20 * 1024, - module.get(), &schedule)); + module.get(), &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -244,10 +242,10 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/17 * 1024, - module.get(), &schedule)); + module.get(), &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -278,10 +276,10 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/15 * 1024, - module.get(), &schedule)); + module.get(), &sequence)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -318,10 +316,10 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/13 * 1024, - module.get(), &schedule)); + module.get(), &sequence)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -384,14 +382,14 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &schedule)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -478,13 +476,13 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,13 +571,13 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - HloSchedule schedule(module.get()); + SequentialHloOrdering::HloModuleSequence sequence; // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &schedule)); + module.get(), &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc deleted file mode 100644 index a65b33bf40..0000000000 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ /dev/null @@ -1,291 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_schedule.h" - -#include -#include - -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/map_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/gtl/map_util.h" - -namespace xla { - -void HloSchedule::set_sequence( - const HloComputation* computation, - absl::Span sequence) { - set_sequence(computation, HloInstructionSequence(sequence)); -} - -void HloSchedule::set_sequence(const HloComputation* computation, - HloInstructionSequence sequence) { - CHECK(computation->parent() == module_); - sequences_[computation->unique_id()] = std::move(sequence); -} - -HloInstructionSequence& HloSchedule::GetOrCreateSequence( - const HloComputation* computation) { - auto it = sequences_.find(computation->unique_id()); - if (it == sequences_.end()) { - // No sequence found for computation. Create and return an empty one. - CHECK(computation->parent() == module_); - return sequences_[computation->unique_id()]; - } else { - return it->second; - } -} - -const HloInstructionSequence& HloSchedule::sequence( - const HloComputation* computation) const { - return sequences_.at(computation->unique_id()); -} - -Status HloSchedule::UpdateComputationSchedule( - const HloComputation* computation) { - // Map from unique ID to HloInstruction pointer for instructions in the - // computation. - tensorflow::gtl::FlatMap id_to_instruction; - for (const HloInstruction* instruction : computation->instructions()) { - InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); - } - - // Set of all HloInstructions in the schedule. - tensorflow::gtl::FlatSet ids_in_schedule; - for (int id : sequences_.at(computation->unique_id()).ids()) { - InsertOrDie(&ids_in_schedule, id); - } - - // Map from HloInstruction X to newly added instructions (instruction is in - // computation, but not in schedule) which use X. If an instruction is not in - // the map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap> - new_instruction_uses; - - // For each newly added instruction, this is the count of the instruction's - // operands that have not yet been scheduled. When this value reaches zero, - // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap - unscheduled_operand_count; - - // Create a worklist of newly added instructions which are ready to be added - // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; - - for (const HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { - // This is a newly added instruction which is not in the schedule. - if (instruction->operands().empty()) { - worklist.push(instruction); - } else { - for (const HloInstruction* operand : instruction->operands()) { - new_instruction_uses[operand].push_back(instruction); - } - unscheduled_operand_count[instruction] = instruction->operand_count(); - } - } - } - - // Update the schedule with the newly added instructions, and remove any - // instructions no longer in the graph. - HloInstructionSequence new_sequence; - - // Lambda which schedules all instructions on the worklist. - auto schedule_worklist = [&]() { - while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); - worklist.pop(); - new_sequence.push_back(instruction); - std::vector* new_users = - tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); - if (new_users != nullptr) { - // This just-scheduled instruction has users which are newly added to - // the module. Update the number of unscheduled operands and push the - // newly added instruction to the worklist if it is ready to - // schedule. - for (const HloInstruction* new_user : *new_users) { - unscheduled_operand_count.at(new_user)--; - CHECK_GE(unscheduled_operand_count.at(new_user), 0); - if (unscheduled_operand_count.at(new_user) == 0) { - worklist.push(new_user); - } - } - } - } - }; - - schedule_worklist(); - for (int id : sequences_.at(computation->unique_id()).ids()) { - auto it = id_to_instruction.find(id); - if (it == id_to_instruction.end()) { - // This instruction in the schedule is no longer in the module. Do not add - // it to the new schedule. - continue; - } - worklist.push(it->second); - schedule_worklist(); - } - - set_sequence(computation, std::move(new_sequence)); - return Status::OK(); -} - -Status HloSchedule::Update() { - // The schedule must contain a sequence for every non-fusion computation in - // the module, but can have sequences for computations which no longer exist - // (these are removed). - std::vector nonfusion_computations = - module_->MakeNonfusionComputations(); - for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) - << "Computation " << computation->name() << " not in HloSchedule."; - } - if (sequences_.size() > nonfusion_computations.size()) { - // Schedule contains some computations which have been removed from the - // HloModule. Remove them from the schedule as well. - tensorflow::gtl::FlatSet nonfusion_computations_ids; - for (const HloComputation* computation : nonfusion_computations) { - nonfusion_computations_ids.insert(computation->unique_id()); - } - for (auto it = sequences_.begin(); it != sequences_.end();) { - if (nonfusion_computations_ids.count(it->first) == 0) { - it = sequences_.erase(it); - } else { - it++; - } - } - } - CHECK_EQ(sequences_.size(), nonfusion_computations.size()); - - for (const HloComputation* computation : nonfusion_computations) { - TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation)); - } - - TF_RETURN_IF_ERROR(Verify()); - return Status::OK(); -} - -Status HloSchedule::Verify() const { - VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(3, module_->ToString()); - XLA_VLOG_LINES(2, ToString()); - - // Verify schedule contains exactly the same set of non-fusion computations as - // module currently does. - std::vector nonfusion_computations = - module_->MakeNonfusionComputations(); - TF_RET_CHECK(nonfusion_computations.size() == sequences_.size()) - << "Schedule has " << sequences_.size() << " sequences, but module has " - << nonfusion_computations.size() << " non-fusion computations"; - for (const HloComputation* computation : nonfusion_computations) { - TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) - << "Computation " << computation->name() - << " missing from HLO schedule."; - } - - // For each computation verify the set of instructions is the same and that - // each dependency and control edge is honored. - for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap instruction_position; - int pos = 0; - for (const HloInstruction* instruction : - sequence(computation).instructions()) { - TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) - << "Instruction " << instruction->name() - << " appears more than once in the schedule"; - pos++; - } - - TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) - << "Instruction " << instruction->name() << " is not in schedule"; - } - - for (const HloInstruction* instruction : computation->instructions()) { - for (const HloInstruction* operand : instruction->operands()) { - TF_RET_CHECK(instruction_position.at(operand) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its operand " << operand->name(); - } - - for (const HloInstruction* pred : instruction->control_predecessors()) { - TF_RET_CHECK(instruction_position.at(pred) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its control predecessor " - << pred->name(); - } - } - } - - return Status::OK(); -} - -namespace { - -// Returns the computation in the given module with the given unique ID. Returns -// nullptr if no such computation exists. -const HloComputation* IdToComputation(const HloModule* module, int64 id) { - for (const HloComputation* computation : module->computations()) { - if (computation->unique_id() == id) { - return computation; - } - } - return nullptr; -} - -} // namespace - -string HloSchedule::ToString() const { - std::vector pieces; - - pieces.push_back("HloSchedule"); - for (const auto& id_sequence : sequences_) { - const HloComputation* computation = - IdToComputation(module_, id_sequence.first); - if (computation == nullptr) { - // The computation is not in the module and may have been deleted so it is - // not safe to dereference any HLO pointers. Just use the HLO unique ids - // stored in this object. - pieces.push_back( - absl::StrFormat("computation with id %d (no longer in HLO module):", - id_sequence.first)); - for (int id : id_sequence.second.ids()) { - pieces.push_back(absl::StrCat(" ", id)); - } - } else { - pieces.push_back(absl::StrFormat("computation %s:", computation->name())); - for (const HloInstruction* instruction : - id_sequence.second.instructions()) { - pieces.push_back(absl::StrCat(" ", instruction->name())); - } - } - } - return absl::StrJoin(pieces, "\n"); -} - -std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { - out << schedule.ToString(); - return out; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h deleted file mode 100644 index 21c6988638..0000000000 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ - -#include - -#include "absl/types/span.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" -#include "tensorflow/compiler/xla/status.h" - -namespace xla { - -// Class representing a sequence of HLO instructions such as the sequential -// execution order of an HLO computation. -class HloInstructionSequence { - public: - HloInstructionSequence() = default; - HloInstructionSequence(absl::Span instructions) { - for (const HloInstruction* instruction : instructions) { - push_back(instruction); - } - } - - // Adds the instruction to the end of the sequence. - void push_back(const HloInstruction* instruction) { - instruction_sequence_.push_back(instruction); - id_sequence_.push_back(instruction->unique_id()); - } - - // Clears the sequence of all instructions. - void clear() { - instruction_sequence_.clear(); - id_sequence_.clear(); - } - - int64 size() const { return instruction_sequence_.size(); } - - // Returns the sequence of HLO instructions. - const std::vector& instructions() const { - return instruction_sequence_; - } - - // Returns the unique IDs of the instructions in the sequence (in order). - const std::vector& ids() const { return id_sequence_; } - - private: - // The sequence as HloInstructions. - std::vector instruction_sequence_; - - // The sequence of HLO instructions, represented by their unique IDs. The - // sequence is stored as both HloInstructions and unique IDs because the - // sequence may be referenced after transformations to the HLO graph and HLO - // pointers can be invalidated or recycled in this process (see - // HloSchedule::Update). - std::vector id_sequence_; -}; - -// A class representing a sequential schedule of instructions for an HLO -// module. A complete HLO schedule contains an instruction sequence for every -// non-fusion computation in the HLO module. -class HloSchedule { - public: - HloSchedule(const HloModule* module) : module_(module) {} - - // Returns a reference to the sequence for the given computation. - const HloInstructionSequence& sequence( - const HloComputation* computation) const; - - // Returns the sequence for the given computation. An empty sequence is - // created if none exists for the computation. - HloInstructionSequence& GetOrCreateSequence( - const HloComputation* computation); - - // Sets the sequence for the given computation to the given sequence. - void set_sequence(const HloComputation* computation, - absl::Span sequence); - void set_sequence(const HloComputation* computation, - HloInstructionSequence sequence); - - // Returns a map from HloComputation unique ID to instruction sequence. The - // map contains all sequences in the schedule. - const tensorflow::gtl::FlatMap& sequences() - const { - return sequences_; - } - - // Returns true if the schedule has a sequence for the given computation. - bool is_computation_scheduled(const HloComputation* computation) const { - return sequences_.count(computation->unique_id()) == 1; - } - - // Updates the schedule such that it is (again) a valid schedule for the - // module. This is used to update a schedule after the HLO module has been - // transformed in some way. In general, the only transformations to the module - // for which a schedule can be updated is the addition or removal of - // instructions and removal of computations. Updating the schedule after new - // dependencies between existing instructions in the module is not supported - // and may result in an error status returned. - // - // Instructions in the module which also exist in the given schedule will - // remain in the same order in the updated schedule. Instructions which exist - // in the module but not in the given schedule will be placed as early as - // possible in the updated schedule. - Status Update(); - - // Verifies that the given schedule is valid for the given module. - // Specifically, the schedule contains exactly the instructions in the - // non-fusion computations in the module and every dependency in the module is - // satisfied in the schedule. - Status Verify() const; - - string ToString() const; - - bool empty() const { return sequences_.empty(); } - - const HloModule* module() const { return module_; } - - private: - // Updates the instruction sequence for the given computation. - Status UpdateComputationSchedule(const HloComputation* computation); - - const HloModule* module_; - - // A map from computation unique ID to instruction sequence. Unique IDs are - // used rather than HloComputation pointers because HLO pointers are not - // unique across HLO transformations because pointers may be recycled. - tensorflow::gtl::FlatMap sequences_; -}; - -std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc deleted file mode 100644 index eb52582bb5..0000000000 --- a/tensorflow/compiler/xla/service/hlo_schedule_test.cc +++ /dev/null @@ -1,341 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_schedule.h" - -#include -#include - -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace xla { -namespace { - -class HloScheduleTest : public HloTestBase {}; - -TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) { - // Updating the schedule of an unchanged HLO module should not affect the - // schedule at all. - const string module_str = R"( -HloModule UpdateScheduleUnchanged - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - const std::vector& entry_schedule = - schedule.sequence(module->entry_computation()).instructions(); - - EXPECT_EQ(entry_schedule.size(), 6); - - TF_ASSERT_OK(schedule.Update()); - TF_ASSERT_OK(schedule.Verify()); - - EXPECT_EQ(entry_schedule, - schedule.sequence(module->entry_computation()).instructions()); -} - -TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) { - // Add some additional instructions to a module and verify the schedule can be - // updated. - const string module_str = R"( -HloModule UpdateScheduleWithNewInstructions - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - - HloComputation* entry = module->entry_computation(); - const Shape shape = entry->root_instruction()->shape(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, constant, entry->root_instruction())); - entry->set_root_instruction(sub); - - auto in_schedule = [&](const HloInstruction* hlo) { - return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo); - }; - - EXPECT_EQ(schedule.sequence(entry).size(), 6); - EXPECT_FALSE(in_schedule(constant)); - EXPECT_FALSE(in_schedule(sub)); - - ASSERT_IS_NOT_OK(schedule.Verify()); - TF_ASSERT_OK(schedule.Update()); - TF_ASSERT_OK(schedule.Verify()); - - EXPECT_EQ(schedule.sequence(entry).size(), 8); - EXPECT_TRUE(in_schedule(constant)); - EXPECT_TRUE(in_schedule(sub)); -} - -TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) { - // Add and delete some instructions from a module and verify that the schedule - // can be updated successfully. - const string module_str = R"( -HloModule UpdateScheduleWithAddedAndDeletedInstruction - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - - // Set the entry root to some expression containing just a parameter and a - // constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* new_root = entry->AddInstruction( - HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, - constant, entry->parameter_instruction(0))); - entry->set_root_instruction(new_root); - - // DCE should remove everything but the parameters and the newly added code. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(schedule.sequence(entry).size(), 6); - - ASSERT_IS_NOT_OK(schedule.Verify()); - TF_ASSERT_OK(schedule.Update()); - TF_ASSERT_OK(schedule.Verify()); - - EXPECT_EQ(schedule.sequence(entry).size(), 4); -} - -TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) { - // Completely replace a module with an entirely new set of instructions and - // verify that the schedule can be updated successfully. - const string module_str = R"( -HloModule UpdateScheduleWithCompletelyReplacedModule - -ENTRY main { - a = f32[] constant(42.0) - b = f32[] constant(123.0) - ROOT sum = f32[] add(a, b) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - - // Replace the entry computation with the negation of a constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kNegate, constant)); - entry->set_root_instruction(new_root); - - // DCE the old instructions. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(schedule.sequence(entry).size(), 3); - - ASSERT_IS_NOT_OK(schedule.Verify()); - TF_ASSERT_OK(schedule.Update()); - TF_ASSERT_OK(schedule.Verify()); - - EXPECT_EQ(schedule.sequence(entry).size(), 2); -} - -TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) { - // Create changes to more than one computation in an HLO module and verify - // that the schedule can be updated. - const string module_str = R"( -HloModule UpdateScheduleWithMultipleComputations - -%Body (param.1: (s32[], token[])) -> (s32[], token[]) { - %param.1 = (s32[], token[]) parameter(0) - %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 - %constant.1 = s32[] constant(1) - %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) - %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %after-all = token[] after-all(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) -} - -%Cond (param: (s32[], token[])) -> pred[] { - %param = (s32[], token[]) parameter(0) - %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 - %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) -} - -ENTRY %WhileLoop () -> s32[] { - %zero = s32[] constant(0) - %init_token = token[] after-all() - %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) - %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body - ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - - const HloInstruction* xla_while = - module->entry_computation()->root_instruction()->operand(0); - HloComputation* body = xla_while->while_body(); - HloComputation* cond = xla_while->while_condition(); - - // Negate the root of the cond. - cond->set_root_instruction(cond->AddInstruction( - HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kNot, cond->root_instruction()))); - - // Replace the body with a computation which just passes through its - // parameter. - body->set_root_instruction(body->parameter_instruction(0)); - - // DCE the dead code in the body. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(schedule.sequence(body).size(), 7); - EXPECT_EQ(schedule.sequence(cond).size(), 4); - - ASSERT_IS_NOT_OK(schedule.Verify()); - TF_ASSERT_OK(schedule.Update()); - TF_ASSERT_OK(schedule.Verify()); - - EXPECT_EQ(schedule.sequence(body).size(), 1); - EXPECT_EQ(schedule.sequence(cond).size(), 5); -} - -TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) { - // Remove computations from a module and verify the schedule can be updated. - const string module_str = R"( -HloModule UpdateScheduleWithMultipleComputations - -%Body (param.1: (s32[], token[])) -> (s32[], token[]) { - %param.1 = (s32[], token[]) parameter(0) - %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 - %constant.1 = s32[] constant(1) - %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) - %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %after-all = token[] after-all(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) -} - -%Cond (param: (s32[], token[])) -> pred[] { - %param = (s32[], token[]) parameter(0) - %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 - %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) -} - -ENTRY %WhileLoop () -> s32[] { - %zero = s32[] constant(0) - %init_token = token[] after-all() - %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) - %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body - ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - - HloInstruction* xla_while = - module->entry_computation()->root_instruction()->mutable_operand(0); - HloInstruction* init = xla_while->mutable_operand(0); - - // Replace the while with its init value. The conditional and body - // computations should then be dead. - TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init)); - - // DCE the dead code in the body. - HloDCE dce; - ASSERT_EQ(module->computation_count(), 3); - TF_ASSERT_OK(dce.Run(module.get()).status()); - ASSERT_EQ(module->computation_count(), 1); - - ASSERT_IS_NOT_OK(schedule.Verify()); - TF_ASSERT_OK(schedule.Update()); - TF_ASSERT_OK(schedule.Verify()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 9bfb0af96c..0fc3b268c0 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -70,7 +70,7 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. - static StatusOr Run( + static StatusOr> Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -229,8 +229,8 @@ class ListScheduler { return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; } - HloInstructionSequence CreateSchedule() { - HloInstructionSequence schedule; + std::vector CreateSchedule() { + std::vector schedule; // Populate the ready list with instructions which have no operands or // control predecessors. @@ -374,7 +374,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr ScheduleComputationHelper( +StatusOr> ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -392,7 +392,7 @@ StatusOr ScheduleComputationHelper( } // namespace -StatusOr DFSMemoryScheduler( +StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -443,7 +443,7 @@ StatusOr DFSMemoryScheduler( // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a // tiebreaker by name for determinism. - HloInstructionSequence sequence; + std::vector sequence; FunctionVisitor visitor([&sequence](HloInstruction* hlo) { sequence.push_back(hlo); return Status::OK(); @@ -463,7 +463,7 @@ StatusOr DFSMemoryScheduler( return sequence; } // namespace xla -StatusOr ListMemoryScheduler( +StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -473,16 +473,18 @@ StatusOr ListMemoryScheduler( memory_by_computation); } -StatusOr PostOrderMemoryScheduler( +StatusOr> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - return HloInstructionSequence(computation.MakeInstructionPostOrder()); + const auto& post_order = computation.MakeInstructionPostOrder(); + return std::vector{post_order.begin(), + post_order.end()}; } -StatusOr DefaultMemoryScheduler( +StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -497,7 +499,7 @@ StatusOr DefaultMemoryScheduler( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. TF_ASSIGN_OR_RETURN( - HloInstructionSequence list_sequence, + std::vector list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, @@ -506,7 +508,7 @@ StatusOr DefaultMemoryScheduler( size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, + TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, @@ -516,7 +518,7 @@ StatusOr DefaultMemoryScheduler( VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( - HloInstructionSequence post_order_sequence, + std::vector post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, @@ -543,35 +545,32 @@ StatusOr DefaultMemoryScheduler( } } -StatusOr ScheduleModule( +StatusOr ScheduleComputationsInModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - HloSchedule schedule(&module); + SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); tensorflow::gtl::FlatMap memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, + TF_ASSIGN_OR_RETURN(auto one_computation_sequence, ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( - *computation, computation_sequence, *points_to_analysis, + *computation, one_computation_sequence, *points_to_analysis, size_function, &memory_by_computation) .ValueOrDie(); - schedule.set_sequence(computation, std::move(computation_sequence)); + sequence[computation] = std::move(one_computation_sequence); } } - VLOG(1) << "Module schedule:\n" << schedule; - - TF_RETURN_IF_ERROR(schedule.Verify()); - - return std::move(schedule); + VLOG(1) << "Module schedule:\n" << sequence; + return sequence; } -StatusOr ScheduleComputation( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); @@ -582,4 +581,187 @@ StatusOr ScheduleComputation( size_function, nullptr, empty_map); } +tensorflow::gtl::FlatMap> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { + tensorflow::gtl::FlatMap> id_sequence; + for (const auto& computation_sequence : sequence) { + for (const HloInstruction* instruction : computation_sequence.second) { + id_sequence[computation_sequence.first].push_back( + instruction->unique_id()); + } + } + return id_sequence; +} + +Status UpdateSchedule( + const HloModule& module, + const tensorflow::gtl::FlatMap>& + id_sequence, + SequentialHloOrdering::HloModuleSequence* sequence) { + // Map from unique ID to HloInstruction pointer for instructions in the + // module. + tensorflow::gtl::FlatMap id_to_instruction; + // Set of all HloInstructions in the schedule. + tensorflow::gtl::FlatSet ids_in_schedule; + std::vector nonfusion_computations = + module.MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK( + id_to_instruction.insert({instruction->unique_id(), instruction}) + .second); + } + for (int id : id_sequence.at(computation)) { + ids_in_schedule.insert(id); + } + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // module, but not in schedule) which use X. If an instruction is not in the + // map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap + unscheduled_operand_count; + // For each computation, this is the set of newly added instructions which + // have no operands. These must be handled specially and are added to the + // beginning of the schedule. + tensorflow::gtl::FlatMap> + new_zero_operand_instructions; + for (const HloComputation* computation : nonfusion_computations) { + new_zero_operand_instructions[computation] = {}; + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + if (instruction->operands().empty()) { + new_zero_operand_instructions[computation].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + for (const HloComputation* computation : nonfusion_computations) { + std::vector old_computation_sequence = + std::move(sequence->at(computation)); + sequence->at(computation).clear(); + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + for (const HloInstruction* instruction : + new_zero_operand_instructions.at(computation)) { + worklist.push(instruction); + } + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + sequence->at(computation).push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : id_sequence.at(computation)) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. + continue; + } + const HloInstruction* instruction = it->second; + worklist.push(instruction); + schedule_worklist(); + } + } + + TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); + return Status::OK(); +} + +Status VerifySchedule( + const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence) { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(2, module.ToString()); + VLOG(2) << sequence; + + // Verify the set of computations in the sequence is exactly the set of + // computations in the module. + std::vector nonfusion_computations = + module.MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); + tensorflow::gtl::FlatSet computations_in_module( + module.computations().begin(), module.computations().end()); + for (const auto& computation_sequence : sequence) { + TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. + for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap instruction_position; + int pos = 0; + for (const HloInstruction* instruction : sequence.at(computation)) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 54e32340ba..d06b8d9a5c 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -33,14 +32,14 @@ namespace xla { // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. -typedef std::function( +typedef std::function>( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler -StatusOr ListMemoryScheduler( +StatusOr> ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -48,7 +47,7 @@ StatusOr ListMemoryScheduler( memory_by_computation); // DFS-order scheduler -StatusOr DFSMemoryScheduler( +StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -56,7 +55,7 @@ StatusOr DFSMemoryScheduler( memory_by_computation); // Naive Post Order scheduler -StatusOr PostOrderMemoryScheduler( +StatusOr> PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -66,26 +65,63 @@ StatusOr PostOrderMemoryScheduler( // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. -StatusOr DefaultMemoryScheduler( +StatusOr> DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation); -// Returns an HloSchedule which seeks to minimize the memory required for +// Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr ScheduleModule( +StatusOr ScheduleComputationsInModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr ScheduleComputation( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); +// Transforms the given schedule such that it is (again) a valid schedule for +// the module. This is used to update a schedule after the HLO module has been +// transformed in some way. In general, the only transformations to the module +// for which a schedule can be updated is the addition or removal of +// instructions to/from the module. Updating the schedule after new dependencies +// between existing instructions in the module is not supported and may result +// in an error status returned. +// +// Instructions in the module which also exist in the given schedule will remain +// in the same order in the updated schedule. Instructions which exist in the +// module but not in the given schedule will be placed as early as possible in +// the updated schedule. +// +// 'id_sequence' is a mirror of the given schedule 'sequence' but with +// HloInstruction ids rather than HloInstruction pointers. This should be +// constructed using ComputeIdSchedule below after the schedule is constructed +// but before the HLO module is transformed. +Status UpdateSchedule( + const HloModule& module, + const tensorflow::gtl::FlatMap>& + id_sequence, + SequentialHloOrdering::HloModuleSequence* sequence); + +// Constructs a copy of the given schedule but with HloInstruction unique ids +// rather than HloInstruction pointers. This is necessary for updating a +// schedule as HloInstruction points in the schedule may become invalid if +// instructions are removed from the module. Used by UpdateSchedule above.. +// TODO(b/113175018): Remove this function when HLO schedule is its own class. +tensorflow::gtl::FlatMap> +ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); + +// Verifies that the given schedule is valid for the given module. Specifically, +// the schedule contains exactly the instructions in the module and every +// dependency in the module is satisfied in the schedule. +Status VerifySchedule(const HloModule& module, + const SequentialHloOrdering::HloModuleSequence& sequence); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 6afe51997e..d49d09d459 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -68,20 +67,19 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, [](const BufferValue& buffer) { + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. - const std::vector& sequence = - schedule.sequence(module->entry_computation()).instructions(); - EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.front()); - EXPECT_EQ(sub, sequence.back()); + EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); + EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); - SequentialHloOrdering ordering(schedule); + SequentialHloOrdering ordering(module.get(), sequence); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); } @@ -110,26 +108,28 @@ ENTRY root { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - const std::vector& sequence = - schedule.sequence(module->entry_computation()).instructions(); - EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); std::unordered_map instructions_by_name; - for (const HloInstruction* instruction : sequence) { + for (const HloInstruction* instruction : + sequence.at(module->entry_computation())) { instructions_by_name[instruction->name()] = instruction; } // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); - EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); + EXPECT_EQ(instructions_by_name.at("param"), + sequence.at(module->entry_computation()).front()); + EXPECT_EQ(instructions_by_name.at("result"), + sequence.at(module->entry_computation()).back()); // Instructions "d" and "e" will both be schedulable at the same time, but // instruction "d" allows us to free the buffer of "p1", so the list scheduler // should prefer it. - SequentialHloOrdering ordering(schedule); + SequentialHloOrdering ordering(module.get(), sequence); EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), instructions_by_name.at("e"))); } @@ -220,13 +220,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(entry_computation->instruction_count(), - schedule.sequence(entry_computation).size()); - SequentialHloOrdering ordering(schedule); + sequence.at(entry_computation).size()); + SequentialHloOrdering ordering(module.get(), sequence); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been // better to schedule it first, instead of during the busy time. @@ -243,13 +243,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), + *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. The output buffer is aliased, // so we don't double count. EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), + *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } @@ -281,18 +281,19 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - schedule.sequence(module->entry_computation()).size()); - SequentialHloOrdering ordering(schedule); + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); // tuple allocates the tuple buffer and doesn't free anything. // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. // abs_abs2 should be scheduled before tuple by List. @@ -331,18 +332,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, - ScheduleModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), 2); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - schedule.sequence(module->entry_computation()).size()); - SequentialHloOrdering ordering(schedule); + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); // fusion allocates memory for the tuple elements and doesn't free anything, // so it's more expensive than exp. EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); @@ -390,12 +391,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - HloSchedule schedule, - ScheduleModule(*module, size_fn, ListMemoryScheduler)); + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); - EXPECT_EQ(module->entry_computation()->instruction_count(), - schedule.sequence(module->entry_computation()).size()); + EXPECT_EQ(entry_computation->instruction_count(), + sequence.at(entry_computation).size()); tensorflow::gtl::FlatMap memory_by_computation; memory_by_computation[cond_computation] = 17; @@ -405,16 +406,262 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), + *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. Cond is the largest one. // The output buffer of the while is aliased. EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, schedule.sequence(entry_computation), + *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } +TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. + const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + std::vector entry_schedule = sequence.begin()->second; + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(entry_schedule, sequence.begin()->second); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. + const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), + hlo) != sequence.at(entry).end(); + }; + + EXPECT_EQ(sequence.at(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 6); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 4); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(entry).size(), 3); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(entry).size(), 2); +} + +TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + tensorflow::gtl::FlatMap> + id_sequence = ComputeIdSchedule(sequence); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. + cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(sequence.at(body).size(), 7); + EXPECT_EQ(sequence.at(cond).size(), 4); + + TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); + TF_ASSERT_OK(VerifySchedule(*module, sequence)); + + EXPECT_EQ(sequence.at(body).size(), 1); + EXPECT_EQ(sequence.at(cond).size(), 5); +} + } // namespace } // namespace xla -- GitLab From 79dc7cd72654fdf9890d9ea5b7a9af15fa7d5d73 Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Wed, 5 Sep 2018 13:25:13 -0700 Subject: [PATCH 135/540] [tf.data]: Fix internal comment. PiperOrigin-RevId: 211687433 --- tensorflow/core/kernels/data/captured_function.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index ae6bdfc2a0..9526da22d1 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -50,8 +50,8 @@ class CapturedFunction { // Creates a new instance from a list of named attributes and captured inputs. // - // If `low_latency_hint` is true, the runtime may use an executor that is - // optimized for small functions. + // If `use_inter_op_parallelism` is false, the runtime may use an executor + // that is optimized for small functions. static Status Create(const NameAttrList& func, std::vector captured_inputs, bool use_inter_op_parallelism, -- GitLab From c9c8de440213355ea4a4d3577fd068d418678d38 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Wed, 5 Sep 2018 13:34:05 -0700 Subject: [PATCH 136/540] Change tags for estimator_test PiperOrigin-RevId: 211688974 --- tensorflow/python/estimator/BUILD | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 9fce172bee..f6ef6d8dcb 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -684,8 +684,10 @@ py_test( shard_count = 4, srcs_version = "PY2AND3", tags = [ + "manual", # b/112769036, b/113907597 + "no_oss", # b/112769036, b/113907597 "no_windows", - "notsan", + "notsan", # b/67510291 ], deps = [ ":keras", -- GitLab From 11caab3c138d06390344c88a4149f1897e3d780d Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 5 Sep 2018 13:50:20 -0700 Subject: [PATCH 137/540] [XLA] Make tensorflow/compiler use absl::{StrCat,string_view,InlinedVector} consistently StringPiece is an alias for absl::string_view, InlinedVector is aliased to absl::InlinedVector. StrCat is compatible, so swapping it out is safe. PiperOrigin-RevId: 211691840 --- tensorflow/compiler/aot/codegen.cc | 44 +++---- tensorflow/compiler/aot/codegen.h | 4 +- tensorflow/compiler/aot/codegen_test.cc | 2 +- .../compiler/aot/embedded_protocol_buffers.cc | 24 ++-- .../compiler/aot/embedded_protocol_buffers.h | 2 +- tensorflow/compiler/aot/tfcompile_main.cc | 7 +- tensorflow/compiler/jit/BUILD | 2 + tensorflow/compiler/jit/deadness_analysis.cc | 12 +- .../jit/encapsulate_subgraphs_pass.cc | 57 ++++----- .../jit/encapsulate_subgraphs_pass_test.cc | 120 +++++++++--------- tensorflow/compiler/jit/graphcycles/BUILD | 1 + .../compiler/jit/graphcycles/graphcycles.cc | 4 +- .../compiler/jit/mark_for_compilation_pass.cc | 40 +++--- .../jit/mark_for_compilation_pass_test.cc | 2 +- .../compiler/jit/partially_decluster_pass.cc | 14 +- .../jit/resource_operation_safety_analysis.cc | 6 +- tensorflow/compiler/jit/xla_cluster_util.cc | 9 +- tensorflow/compiler/jit/xla_cluster_util.h | 2 +- .../compiler/jit/xla_compilation_cache.cc | 6 +- tensorflow/compiler/jit/xla_device.cc | 5 +- tensorflow/compiler/jit/xla_device_context.cc | 4 +- tensorflow/compiler/jit/xla_device_context.h | 4 +- .../compiler/jit/xla_fusion_optimizer.cc | 3 +- tensorflow/compiler/jit/xla_tensor.h | 2 +- tensorflow/compiler/tests/BUILD | 1 + tensorflow/compiler/tests/randomized_tests.cc | 50 ++++---- tensorflow/compiler/tf2xla/BUILD | 5 +- tensorflow/compiler/tf2xla/dump_graph.cc | 8 +- .../compiler/tf2xla/functionalize_cond.cc | 36 +++--- .../tf2xla/functionalize_control_flow_util.cc | 2 +- .../tf2xla/functionalize_control_flow_util.h | 13 +- .../compiler/tf2xla/functionalize_while.cc | 6 +- tensorflow/compiler/tf2xla/graph_compiler.cc | 2 +- tensorflow/compiler/tf2xla/graph_compiler.h | 2 +- .../tf2xla/kernels/batchtospace_op.cc | 2 +- .../compiler/tf2xla/kernels/bcast_ops.cc | 4 +- .../tf2xla/kernels/depthtospace_op.cc | 2 +- .../compiler/tf2xla/kernels/pooling_ops.cc | 2 +- .../tf2xla/kernels/reduction_ops_common.cc | 4 +- .../compiler/tf2xla/kernels/reverse_op.cc | 2 +- .../compiler/tf2xla/kernels/shape_op.cc | 2 +- .../tf2xla/kernels/spacetobatch_op.cc | 2 +- .../tf2xla/kernels/spacetodepth_op.cc | 2 +- .../compiler/tf2xla/kernels/stack_ops.cc | 2 +- .../tf2xla/kernels/strided_slice_op.cc | 28 ++-- .../tf2xla/kernels/tensor_array_ops.cc | 2 +- .../compiler/tf2xla/kernels/transpose_op.cc | 2 +- tensorflow/compiler/tf2xla/lib/BUILD | 2 +- tensorflow/compiler/tf2xla/lib/while_loop.cc | 8 +- tensorflow/compiler/tf2xla/lib/while_loop.h | 6 +- .../tf2xla/resource_operation_table.cc | 22 ++-- .../tf2xla/resource_operation_table.h | 8 +- .../tf2xla/resource_operation_table_test.cc | 2 +- tensorflow/compiler/tf2xla/sharding_util.cc | 1 - tensorflow/compiler/tf2xla/tf2xla.cc | 10 +- tensorflow/compiler/tf2xla/tf2xla_util.cc | 10 +- tensorflow/compiler/tf2xla/tf2xla_util.h | 2 +- .../compiler/tf2xla/tf2xla_util_test.cc | 6 +- .../compiler/tf2xla/xla_compilation_device.cc | 11 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 19 ++- tensorflow/compiler/tf2xla/xla_context.cc | 1 - tensorflow/compiler/tf2xla/xla_op_kernel.cc | 26 ++-- tensorflow/compiler/tf2xla/xla_op_kernel.h | 26 ++-- tensorflow/compiler/tf2xla/xla_op_registry.cc | 20 +-- tensorflow/compiler/tf2xla/xla_op_registry.h | 16 +-- tensorflow/compiler/tf2xla/xla_resource.cc | 4 +- tensorflow/compiler/xla/service/BUILD | 2 + .../service/gpu/multi_output_fusion_test.cc | 10 +- tensorflow/compiler/xla/service/hlo_cse.cc | 2 +- .../service/while_loop_constant_sinking.cc | 2 +- tensorflow/compiler/xrt/BUILD | 1 + tensorflow/compiler/xrt/kernels/BUILD | 5 +- .../compiler/xrt/kernels/xrt_compile_ops.cc | 4 +- tensorflow/compiler/xrt/xrt_state.cc | 8 +- 74 files changed, 399 insertions(+), 392 deletions(-) diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 2b1ce34b37..b17bc658fa 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/types/span.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace tfcompile { @@ -135,12 +135,12 @@ Status AddRewritesForShape(int i, const xla::Shape& shape, indices = "[0]"; } else { for (int dim = 0; dim < shape.dimensions_size(); ++dim) { - dim_vars.push_back(strings::StrCat("size_t dim", dim)); - dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]"); - indices += strings::StrCat("[dim", dim, "]"); + dim_vars.push_back(absl::StrCat("size_t dim", dim)); + dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); + indices += absl::StrCat("[dim", dim, "]"); } } - rewrites->push_back({"{{I}}", strings::StrCat(i)}); + rewrites->push_back({"{{I}}", absl::StrCat(i)}); rewrites->push_back({"{{TYPE}}", type}); rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")}); rewrites->push_back({"{{DIM_SIZES}}", dim_sizes}); @@ -194,7 +194,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, arg_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.feed(i).name().empty()) { *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites); } @@ -235,7 +235,7 @@ Status GenResultMethods(const tf2xla::Config& config, result_data({{I}}))){{INDICES}}; } )"; - *methods += RewriteWithName(strings::StrCat(i), code, rewrites); + *methods += RewriteWithName(absl::StrCat(i), code, rewrites); if (!config.fetch(i).name().empty()) { *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites); } @@ -304,8 +304,8 @@ std::vector BufferInfosToCppExpression( string encoded_second_as_str = encoded.second == ~0ULL ? "~0ULL" - : strings::StrCat(encoded.second, "ULL"); - return strings::StrCat( + : absl::StrCat(encoded.second, "ULL"); + return absl::StrCat( "::tensorflow::cpu_function_runtime::BufferInfo({", encoded.first, "ULL, ", encoded_second_as_str, "})"); }); @@ -352,13 +352,13 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { - ns_start += strings::StrCat("namespace ", n, " {\n"); + ns_start += absl::StrCat("namespace ", n, " {\n"); } ns_start += "\n"; string ns_end("\n"); for (int i = opts.namespaces.size() - 1; i >= 0; --i) { const string& n = opts.namespaces[i]; - ns_end += strings::StrCat("} // end namespace ", n, "\n"); + ns_end += absl::StrCat("} // end namespace ", n, "\n"); } // Generate metadata. @@ -568,10 +568,10 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { )"; // The replacement strategy is naive, but good enough for our purposes. const std::vector> rewrites = { - {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, - {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)}, + {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)}, {"{{ARG_NAMES_CODE}}", arg_names_code}, - {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())}, + {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())}, {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")}, {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size}, {"{{CLASS}}", opts.class_name}, @@ -590,11 +590,11 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, - {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, + {"{{RESULT_INDEX}}", absl::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, - {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, - {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, - {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())}, + {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)}, + {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)}, + {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())}, {"{{BUFFER_INFOS_AS_STRING}}", absl::StrJoin(buffer_infos_as_strings, ",\n")}}; absl::StrReplaceAll(rewrites, header); @@ -602,13 +602,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { } static string CreateUniqueIdentifier(const CodegenOpts& opts, - StringPiece suffix) { + absl::string_view suffix) { string result = "__tfcompile"; for (const string& n : opts.namespaces) { - strings::StrAppend(&result, "_", n); + absl::StrAppend(&result, "_", n); } - strings::StrAppend(&result, "_", opts.class_name, "_", suffix); + absl::StrAppend(&result, "_", opts.class_name, "_", suffix); return result; } @@ -678,7 +678,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name, return Status::OK(); } -Status ValidateCppIdent(StringPiece ident, StringPiece msg) { +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) { if (ident.empty()) { return errors::InvalidArgument("empty identifier: ", msg); } diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 83f2d3ee11..90410c46a8 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { namespace tfcompile { @@ -96,7 +96,7 @@ Status ParseCppClass(const string& cpp_class, string* class_name, // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is // appended to error messages. -Status ValidateCppIdent(StringPiece ident, StringPiece msg); +Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index e3a53edb73..bb288d2300 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc index f1e8e5c084..3c32d533f6 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -38,11 +38,11 @@ using xla::llvm_ir::AsStringRef; static void AddEmbeddedProtocolBufferToLlvmModule( llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto, - StringPiece unique_identifier, string* protobuf_array_symbol_name, + absl::string_view unique_identifier, string* protobuf_array_symbol_name, int64* protobuf_array_size) { string protobuf_array_contents = proto.SerializeAsString(); *protobuf_array_symbol_name = - strings::StrCat(unique_identifier, "_protobuf_array_contents"); + absl::StrCat(unique_identifier, "_protobuf_array_contents"); *protobuf_array_size = protobuf_array_contents.size(); llvm::Constant* protobuf_array_initializer = @@ -55,9 +55,9 @@ static void AddEmbeddedProtocolBufferToLlvmModule( protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); } -static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, - StringPiece protobuf_array_symbol_name, - int64 protobuf_array_size) { +static string CreateCPPShimExpression( + absl::string_view qualified_cpp_protobuf_name, + absl::string_view protobuf_array_symbol_name, int64 protobuf_array_size) { string code = "[]() {\n" " {{PROTOBUF_NAME}}* proto = new {{PROTOBUF_NAME}};\n" @@ -68,9 +68,9 @@ static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, return absl::StrReplaceAll( code, { - {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, - {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, - {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, + {"{{ARRAY_SYMBOL}}", absl::StrCat(protobuf_array_symbol_name)}, + {"{{ARRAY_SIZE}}", absl::StrCat(protobuf_array_size)}, + {"{{PROTOBUF_NAME}}", absl::StrCat(qualified_cpp_protobuf_name)}, }); } @@ -93,7 +93,7 @@ static StatusOr CodegenModule(llvm::TargetMachine* target_machine, } static StatusOr> -GetTargetMachineFromTriple(StringPiece target_triple) { +GetTargetMachineFromTriple(absl::string_view target_triple) { std::string error; std::string normalized_triple = llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); @@ -110,7 +110,7 @@ GetTargetMachineFromTriple(StringPiece target_triple) { } StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, + absl::string_view target_triple, absl::Span protobufs_to_embed) { TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, GetTargetMachineFromTriple(target_triple)); @@ -135,8 +135,8 @@ StatusOr CreateEmbeddedProtocolBuffers( protobuf_to_embed.qualified_cpp_protobuf_name, protobuf_array_symbol_name, protobuf_array_size); - cpp_variable_decl = strings::StrCat("extern \"C\" char ", - protobuf_array_symbol_name, "[];"); + cpp_variable_decl = + absl::StrCat("extern \"C\" char ", protobuf_array_symbol_name, "[];"); } else { cpp_shim = "nullptr"; } diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h index 4f940c0197..cf5c04ac4b 100644 --- a/tensorflow/compiler/aot/embedded_protocol_buffers.h +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -83,7 +83,7 @@ struct ProtobufToEmbed { // is stored in the object_file_data field in the returned // EmbeddedProtocolBuffers instance. StatusOr CreateEmbeddedProtocolBuffers( - StringPiece target_triple, + absl::string_view target_triple, absl::Span protobufs_to_embed); } // namespace tfcompile diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index f3c44e9dda..b95b063348 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" @@ -92,8 +92,9 @@ Status Main(const MainFlags& flags) { // Write output files. Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, - StringPiece(obj.data(), obj.size()))); + TF_RETURN_IF_ERROR( + WriteStringToFile(env, flags.out_function_object, + absl::string_view(obj.data(), obj.size()))); CodegenOpts codegen_opts; codegen_opts.gen_name_to_index = flags.gen_name_to_index; codegen_opts.gen_program_shape = flags.gen_program_shape; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index df81f3c23e..de7cd26d1d 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -410,6 +410,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -566,6 +567,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 82aa03810b..9128b48da3 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -154,7 +154,7 @@ class AndPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " & "), ")"); } Kind kind() const override { return Kind::kAnd; } @@ -185,7 +185,7 @@ class OrPredicate : public Predicate { std::back_inserter(operands_str), [](Predicate* pred) { return pred->ToString(); }); - return strings::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); + return absl::StrCat("(", absl::StrJoin(operands_str, " | "), ")"); } Kind kind() const override { return Kind::kOr; } @@ -206,7 +206,7 @@ class NotPredicate : public Predicate { operands_({operand}) {} string ToString() const override { - return strings::StrCat("~", operand()->ToString()); + return absl::StrCat("~", operand()->ToString()); } Kind kind() const override { return Kind::kNot; } @@ -240,8 +240,8 @@ class AndRecurrencePredicate : public Predicate { Predicate* step() const { return operands_[1]; } string ToString() const override { - return strings::StrCat("{", start()->ToString(), ",&,", step()->ToString(), - "}"); + return absl::StrCat("{", start()->ToString(), ",&,", step()->ToString(), + "}"); } Kind kind() const override { return Kind::kAndRecurrence; } @@ -267,7 +267,7 @@ class SymbolPredicate : public Predicate { must_be_true_(must_be_true) {} string ToString() const override { - return must_be_true() ? strings::StrCat("*", tensor_id_.ToString()) + return must_be_true() ? absl::StrCat("*", tensor_id_.ToString()) : tensor_id_.ToString(); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 2788102620..ae7a22f451 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/device_name_utils.h" @@ -755,7 +755,7 @@ Status Encapsulator::Subgraph::RecordArg( if (inserted) { NodeDef arg_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + absl::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); @@ -790,7 +790,7 @@ Status Encapsulator::Subgraph::RecordResult( if (inserted) { NodeDef ret_def; NodeDefBuilder builder( - strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + absl::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); DataType dtype = src_node->output_type(src_slot); builder.Attr("T", dtype); builder.Attr("index", ret_index); @@ -950,16 +950,15 @@ Status Encapsulator::Subgraph::AddHostComputes( } NodeDef host_compute_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", - oc_subgraph_name, "_host_compute"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), kHostComputeOp); builder.Input(inputs); builder.Attr("Tinputs", input_dtypes); builder.Attr("Toutputs", output_dtypes); builder.Attr("ancestors", host_compute_ancestors); - builder.Attr("key", - strings::StrCat("host_compute_channel_", subgraph_name, "_", - oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); builder.Attr("_outside_compilation_subgraph", oc_subgraph_name); Status s = builder.Finalize(&host_compute_def); if (!s.ok()) return s; @@ -1017,8 +1016,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; - NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), - "NoOp"); + NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); Status s = builder.Finalize(&seq_def); @@ -1091,10 +1089,10 @@ Status Encapsulator::Subgraph::BuildFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); + dump_graph::DumpGraphToFile(absl::StrCat("encapsulate_fdef_graph_", name), + *graph_, library); + dump_graph::DumpFunctionDefToFile(absl::StrCat("encapsulate_fdef_", name), + fdef); } if (!reuse_existing_functions || library->Find(name) == nullptr) { @@ -1130,8 +1128,8 @@ Status Encapsulator::Subgraph::AddShapeInferenceInfo( host_compute->AddAttr("shapes", shapes); } else { string inference_graph_name = - strings::StrCat("_outside_compilation_shape_inference_", subgraph_name, - "_", outside_compilation_subgraph_name); + absl::StrCat("_outside_compilation_shape_inference_", subgraph_name, + "_", outside_compilation_subgraph_name); FunctionDef fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*inference_graph, inference_graph_name, &fdef)); @@ -1155,10 +1153,10 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( if (VLOG_IS_ON(1)) { VLOG(2) << "Replace function def " << name; dump_graph::DumpGraphToFile( - strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + absl::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, library); dump_graph::DumpFunctionDefToFile( - strings::StrCat("replace_encapsulate_fdef_", name), fdef); + absl::StrCat("replace_encapsulate_fdef_", name), fdef); } TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); @@ -1186,8 +1184,7 @@ Status Encapsulator::Subgraph::AddHostComputeKeyPlaceholder( GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); NodeDef key_def; NodeDefBuilder builder( - strings::StrCat(call_node_def_.name(), "_key_placeholder"), - "Placeholder"); + absl::StrCat(call_node_def_.name(), "_key_placeholder"), "Placeholder"); builder.Attr("dtype", DT_STRING); builder.Attr("shape", shape_proto); builder.Attr("_host_compute_call_node", call_node_def_.name()); @@ -1221,16 +1218,16 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode( } NodeDef recv_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_recv"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), kRecvAtHostOp); builder.Device(device_); builder.Attr("Toutputs", dtypes); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); builder.Attr(group_attribute, subgraph_name); builder.Attr(outside_compilation_attribute, oc_subgraph_name); builder.Input(host_compute_key_placeholder_->name(), 0, DT_STRING); @@ -1276,13 +1273,13 @@ Status Encapsulator::Subgraph::AddSendFromHostNode( } NodeDef send_def; - NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, - "_", oc_subgraph_name, "_send"), + NodeDefBuilder builder(absl::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), kSendFromHostOp); builder.Device(device_); builder.Attr("Tinputs", dtypes); - builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, - "_", oc_subgraph_name)); + builder.Attr("key", absl::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); // The correct device_ordinal will be inserted during replication in a // subsequent rewrite. builder.Attr("device_ordinal", 0); @@ -1516,7 +1513,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { // Dump subgraphs. for (auto& entry : subgraphs_) { dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_subgraphs_subgraph_", entry.first), + absl::StrCat("encapsulate_subgraphs_subgraph_", entry.first), *entry.second.GetGraph(), library); } } @@ -2052,7 +2049,7 @@ struct PathDetails { struct SubgraphAndClusterHash { inline std::size_t operator()(const SubgraphAndCluster& v) const { return hash()( - strings::StrCat(v.subgraph, v.outside_compilation_cluster)); + absl::StrCat(v.subgraph, v.outside_compilation_cluster)); } }; diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 7bc0ef0303..49958093b8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include "absl/strings/match.h" @@ -48,7 +49,7 @@ Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, FunctionDef* fdef = library->add_function(); TF_RETURN_IF_ERROR(GraphToFunctionDef( *graph, - strings::StrCat("_outside_compilation_shape_inference_", name_suffix), + absl::StrCat("_outside_compilation_shape_inference_", name_suffix), fdef)); return Status::OK(); } @@ -65,18 +66,18 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = b.find(elt_a.first); if (iter == b.end()) { if (diff) { - *diff = strings::StrCat( - map_name, " expected: contains element with key '", - key_to_string(elt_a.first), "' got: map has no such element"); + *diff = absl::StrCat(map_name, " expected: contains element with key '", + key_to_string(elt_a.first), + "' got: map has no such element"); } return false; } if (!compare(elt_a.first, elt_a.second, iter->second)) { if (diff) { - *diff = strings::StrCat(map_name, " expected: element with key '", - key_to_string(elt_a.first), "' has value '", - value_to_string(elt_a.second), "' got: '", - value_to_string(iter->second), "'"); + *diff = absl::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), "' has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); } return false; } @@ -85,9 +86,9 @@ bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, const auto iter = a.find(elt_b.first); if (iter == a.end()) { if (diff) { - *diff = strings::StrCat(map_name, " got: contains element with key '", - key_to_string(elt_b.first), - "' expected: map has no such element"); + *diff = absl::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); } return false; } @@ -99,25 +100,25 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, const string& diff_preamble, string* diff) { if (a.op() != b.op()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected op '", a.op(), "' got '", b.op()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); } return false; } if (a.device() != b.device()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected device '", a.device(), "' got '", - b.device()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); } return false; } if (a.input_size() != b.input_size()) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - ", expected ", a.input_size(), " inputs got ", - b.input_size(), " expected:\n", a.DebugString(), - "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } @@ -127,10 +128,10 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, if (absl::StartsWith(a.input(i), "^")) { if (!absl::StartsWith(b.input(i), "^")) { if (diff) { - *diff = strings::StrCat( - diff_preamble, " mismatch for node ", a.name(), " input ", i, - ", expected control input ", a.input(i), " got ", b.input(i), - " expected:\n", a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected control input ", + a.input(i), " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -138,19 +139,19 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, control_input_b.insert(b.input(i)); } else if (a.input(i) != b.input(i)) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " input ", i, ", expected ", a.input(i), - " got ", b.input(i), " expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), " got ", + b.input(i), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); } return false; } } if (control_input_a != control_input_b) { if (diff) { - *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), - " control inputs differ expected:\n", - a.DebugString(), "\ngot:\n", b.DebugString()); + *diff = absl::StrCat(diff_preamble, " mismatch for node ", a.name(), + " control inputs differ expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); } return false; } @@ -170,18 +171,17 @@ bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, return av.DebugString() == bv.DebugString(); } }, - strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), - diff); + absl::StrCat(diff_preamble, " attr mismatch for node ", a.name()), diff); } bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Signature mismatch for function ", - a.signature().name(), ", expected:\n", - a.signature().DebugString(), "\ngot:\n", - b.signature().DebugString()); + *diff = + absl::StrCat("Signature mismatch for function ", a.signature().name(), + ", expected:\n", a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } @@ -191,7 +191,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const AttrValue& av, const AttrValue& bv) { return av.DebugString() == bv.DebugString(); }, - strings::StrCat("attr mismatch for function ", a.signature().name()), + absl::StrCat("attr mismatch for function ", a.signature().name()), diff)) { return false; } @@ -201,7 +201,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, [](const string& key, const string& av, const string& bv) { return av == bv; }, - strings::StrCat("ret mismatch for function ", a.signature().name()), + absl::StrCat("ret mismatch for function ", a.signature().name()), diff)) { return false; } @@ -211,7 +211,7 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, if (a.node_def(i).name() == b.node_def(j).name()) { if (!EqualFunctionNodeDef( a.node_def(i), b.node_def(j), - strings::StrCat("Function ", a.signature().name()), diff)) { + absl::StrCat("Function ", a.signature().name()), diff)) { return false; } found = true; @@ -220,9 +220,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", expected: has node '", a.node_def(i).name(), - "' got: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); } return false; } @@ -237,9 +237,9 @@ bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, } if (!found) { if (diff) { - *diff = strings::StrCat("Function ", a.signature().name(), - ", got: has node '", b.node_def(i).name(), - "' expected: no node of that name"); + *diff = absl::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); } return false; } @@ -258,8 +258,8 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, auto it = actual_index.find(expected_function.signature().name()); if (it == actual_index.end()) { if (diff) { - *diff = strings::StrCat("Did not find expected function '", - expected_function.signature().name(), "'"); + *diff = absl::StrCat("Did not find expected function '", + expected_function.signature().name(), "'"); } return false; } @@ -269,9 +269,9 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, if (!actual_index.empty()) { if (diff != nullptr) { - *diff = strings::StrCat("Found unexpected function '", - actual_index.begin()->second->signature().name(), - "'"); + *diff = + absl::StrCat("Found unexpected function '", + actual_index.begin()->second->signature().name(), "'"); } return false; } @@ -420,10 +420,9 @@ Node* RecvAtHost(ops::NodeOut key_input, const string& cluster, const string& oc_cluster, absl::Span dtypes, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_recv"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_recv"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaRecvAtHost"), "_XlaRecvAtHost", opts.op_registry()); node_builder.Input(std::move(key_input)); @@ -440,10 +439,9 @@ Node* SendFromHost(ops::NodeOut key_input, const string& cluster, const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; - string key = - strings::StrCat("host_compute_channel_", cluster, "_", oc_cluster); - string name = strings::StrCat("outside_compilation_", cluster, "_", - oc_cluster, "_send"); + string key = absl::StrCat("host_compute_channel_", cluster, "_", oc_cluster); + string name = + absl::StrCat("outside_compilation_", cluster, "_", oc_cluster, "_send"); NodeBuilder node_builder(opts.WithName(name).GetNameForOp("_XlaSendFromHost"), "_XlaSendFromHost", opts.op_registry()); node_builder.Input(inputs); @@ -682,8 +680,8 @@ std::vector> GraphEdges(const Graph& graph) { for (const Edge* edge : graph.edges()) { if (edge->src()->IsSource() || edge->dst()->IsSink()) continue; edges.emplace_back( - strings::StrCat(edge->src()->name(), ":", edge->src_output()), - strings::StrCat(edge->dst()->name(), ":", edge->dst_input())); + absl::StrCat(edge->src()->name(), ":", edge->src_output()), + absl::StrCat(edge->dst()->name(), ":", edge->dst_input())); } std::sort(edges.begin(), edges.end()); return edges; diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD index 676f71a75a..8212956adf 100644 --- a/tensorflow/compiler/jit/graphcycles/BUILD +++ b/tensorflow/compiler/jit/graphcycles/BUILD @@ -14,6 +14,7 @@ cc_library( hdrs = ["graphcycles.h"], deps = [ "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc index 805bbc62c1..756377bd95 100644 --- a/tensorflow/compiler/jit/graphcycles/graphcycles.cc +++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc @@ -34,7 +34,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -44,7 +44,7 @@ namespace { typedef std::unordered_set NodeSet; template struct VecStruct { - typedef gtl::InlinedVector type; + typedef absl::InlinedVector type; }; template using Vec = typename VecStruct::type; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 4e4abade32..44caf0be52 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -43,7 +43,6 @@ limitations under the License. #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/public/version.h" @@ -617,7 +616,7 @@ Status MarkForCompilationPass::Run( } static string RatioToString(int numerator, int denominator) { - return strings::Printf("%d / %d (%.2f%%)", numerator, denominator, + return absl::StrFormat("%d / %d (%.2f%%)", numerator, denominator, (100.0 * numerator) / denominator); } @@ -626,14 +625,14 @@ static void VLogClusteringSummary(const Graph& g) { return; } - std::map cluster_name_to_size; - std::map> + std::map cluster_name_to_size; + std::map> cluster_name_to_op_histogram; - std::map unclustered_op_histogram; + std::map unclustered_op_histogram; int clustered_node_count = 0; for (Node* n : g.nodes()) { - absl::optional cluster_name = GetXlaClusterForNode(*n); + absl::optional cluster_name = GetXlaClusterForNode(*n); if (cluster_name) { clustered_node_count++; cluster_name_to_size[*cluster_name]++; @@ -650,7 +649,7 @@ static void VLogClusteringSummary(const Graph& g) { << RatioToString(clustered_node_count, g.num_nodes()); for (const auto& cluster_name_size_pair : cluster_name_to_size) { - StringPiece cluster_name = cluster_name_size_pair.first; + absl::string_view cluster_name = cluster_name_size_pair.first; int size = cluster_name_size_pair.second; VLOG(2) << " " << cluster_name << " " << RatioToString(size, g.num_nodes()); @@ -670,14 +669,15 @@ static void VLogClusteringSummary(const Graph& g) { } struct EdgeInfo { - StringPiece node_name; - absl::optional cluster_name; + absl::string_view node_name; + absl::optional cluster_name; - StringPiece GetClusterName() const { + absl::string_view GetClusterName() const { return cluster_name ? *cluster_name : "[none]"; } - std::pair> AsPair() const { + std::pair> AsPair() + const { return {node_name, cluster_name}; } @@ -686,19 +686,21 @@ static void VLogClusteringSummary(const Graph& g) { } }; - using EdgeInfoMap = std::map>; + using EdgeInfoMap = std::map>; EdgeInfoMap incoming_edge_infos; EdgeInfoMap outgoing_edge_infos; - std::set cluster_names_to_print; + std::set cluster_names_to_print; for (const Edge* e : g.edges()) { const Node* from = e->src(); - absl::optional from_cluster_name = GetXlaClusterForNode(*from); + absl::optional from_cluster_name = + GetXlaClusterForNode(*from); const Node* to = e->dst(); - absl::optional to_cluster_name = GetXlaClusterForNode(*to); + absl::optional to_cluster_name = + GetXlaClusterForNode(*to); if (to_cluster_name == from_cluster_name) { continue; @@ -721,9 +723,9 @@ static void VLogClusteringSummary(const Graph& g) { VLOG(2) << " [none]"; } - auto print_edge_info_set_for_cluster = [&](StringPiece cluster_name, + auto print_edge_info_set_for_cluster = [&](absl::string_view cluster_name, const EdgeInfoMap& edge_info_map, - StringPiece desc) { + absl::string_view desc) { auto it = edge_info_map.find(cluster_name); if (it != edge_info_map.end()) { VLOG(2) << " " << it->second.size() << " " << desc << " edges"; @@ -737,7 +739,7 @@ static void VLogClusteringSummary(const Graph& g) { } }; - for (StringPiece cluster_name : cluster_names_to_print) { + for (absl::string_view cluster_name : cluster_names_to_print) { VLOG(2) << " ** Cluster " << cluster_name; print_edge_info_set_for_cluster(cluster_name, incoming_edge_infos, "incoming"); @@ -966,7 +968,7 @@ Status MarkForCompilationPass::RunImpl( string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 807ab51fd3..9473ac0a4c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,7 +633,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { std::unique_ptr graph(new Graph(OpRegistry::Global())); Scope root = Scope::NewRootScope().ExitOnError(); { - auto BuildNoopNode = [](StringPiece name, Graph* graph) { + auto BuildNoopNode = [](absl::string_view name, Graph* graph) { NodeDefBuilder builder(name, "NoOp"); NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index a8f09bfa50..584c963f71 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/partially_decluster_pass.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -30,7 +31,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, MemoryTypeVector input_mtypes, output_mtypes; for (Node* n : post_order) { - absl::optional from_cluster = GetXlaClusterForNode(*n); + absl::optional from_cluster = GetXlaClusterForNode(*n); if (!from_cluster) { continue; } @@ -79,7 +80,7 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, // Check if `dst` is in a different cluster, unclustered, or about to be // partially declustered (here we rely on the post-order traversal order). // If yes, decluster `n` to avoid the device-to-host memcpy. - absl::optional dst_cluster = + absl::optional dst_cluster = result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst); if (from_cluster != dst_cluster) { CHECK(result->insert(n).second); @@ -91,15 +92,16 @@ Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet* result, } Status PartiallyDeclusterNode(Graph* graph, Node* n) { - StringPiece cluster_name = *GetXlaClusterForNode(*n); - gtl::InlinedVector out_edges_to_clone; + absl::string_view cluster_name = *GetXlaClusterForNode(*n); + absl::InlinedVector out_edges_to_clone; for (const Edge* out_edge : n->out_edges()) { if (out_edge->IsControlEdge()) { continue; } Node* dst = out_edge->dst(); - absl::optional dst_cluster_name = GetXlaClusterForNode(*dst); + absl::optional dst_cluster_name = + GetXlaClusterForNode(*dst); if (dst_cluster_name != cluster_name) { out_edges_to_clone.push_back(out_edge); } @@ -108,7 +110,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { CHECK(!out_edges_to_clone.empty()) << n->DebugString(); NodeDef ndef = n->def(); - ndef.set_name(strings::StrCat(n->name(), "/declustered")); + ndef.set_name(absl::StrCat(n->name(), "/declustered")); RemoveFromXlaCluster(&ndef); Status s; Node* cloned_node = graph->AddNode(ndef, &s); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 1ba4a5ef73..56e35c0059 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -165,7 +165,7 @@ bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { using ResourceOp = std::pair; string ResourceOpToString(const ResourceOp& resource_op) { - return strings::StrCat( + return absl::StrCat( resource_op.first, ": ", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); } @@ -257,11 +257,11 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { std::vector elements_debug_string; std::transform(resource_op_set.begin(), resource_op_set.end(), std::back_inserter(elements_debug_string), ResourceOpToString); - return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); + return absl::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { - return strings::StrCat( + return absl::StrCat( "[", n.name(), ": ", n.type_string(), "(", XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); } diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 4f2fabd658..03380e9406 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/control_flow.h" @@ -52,8 +53,8 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, }; string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); + absl::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); path.resize(path_size); for (int32 node_id : path) { string ascii_art; @@ -64,7 +65,7 @@ string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, } else { ascii_art = "+-- "; } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + absl::StrAppend(&description, ascii_art, node_name(node_id), "\n"); } return description; } @@ -186,7 +187,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { return Status::OK(); } -absl::optional GetXlaClusterForNode(const Node& node) { +absl::optional GetXlaClusterForNode(const Node& node) { const AttrValue* attr_value = node.attrs().Find(kXlaClusterAttr); if (attr_value == nullptr) { return absl::nullopt; diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h index b0439a63ca..17ae510a0e 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.h +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -47,7 +47,7 @@ Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); // Returns the XLA cluster in which `node` is placed if it is in an XLA cluster, // otherwise returns nullopt. -absl::optional GetXlaClusterForNode(const Node& node); +absl::optional GetXlaClusterForNode(const Node& node); // Removes `node_def` its XLA cluster (by clearing its _XlaCluster attribute). void RemoveFromXlaCluster(NodeDef* node_def); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index dcb0b3240a..3aa9e9c7ed 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -67,12 +67,12 @@ string XlaCompilationCache::DebugString() { string XlaCompilationCache::SignatureDebugString(const Signature& sig) { string result = sig.name; for (const auto& a : sig.arg_types) { - strings::StrAppend(&result, ",", DataTypeString(a.first), - a.second.DebugString()); + absl::StrAppend(&result, ",", DataTypeString(a.first), + a.second.DebugString()); } for (const auto& v : sig.arg_values) { - strings::StrAppend(&result, "; ", v.DebugString()); + absl::StrAppend(&result, "; ", v.DebugString()); } return result; } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index f31879a2bc..51797def04 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -148,10 +148,9 @@ Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { } const DeviceAttributes attrs = Device::BuildDeviceAttributes( - strings::StrCat(name_prefix, "/device:", device_name, ":", - device_ordinal), + absl::StrCat(name_prefix, "/device:", device_name, ":", device_ordinal), DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), - strings::StrCat("device: ", device_name, " device")); + absl::StrCat("device: ", device_name, " device")); device->reset( new XlaDevice(options, attrs, device_ordinal, DeviceType(jit_device_name), diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ee07c5c964..af83c792e5 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -203,7 +203,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { @@ -339,7 +339,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) { manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 2e7445340c..df82421294 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -57,7 +57,7 @@ class XlaTransferManager { void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, @@ -111,7 +111,7 @@ class XlaDeviceContext : public DeviceContext { Tensor* device_tensor, StatusCallback done) const override; void CopyDeviceTensorToCPU(const Tensor* device_tensor, - StringPiece tensor_name, Device* device, + absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, const StatusCallback& done); diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc index 07cfab6151..bc0db558d8 100644 --- a/tensorflow/compiler/jit/xla_fusion_optimizer.cc +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/jit/deadness_analysis.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" @@ -326,7 +327,7 @@ Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, string& name = cluster_names[cluster]; if (name.empty()) { - name = strings::StrCat("cluster_", cluster_sequence_num++); + name = absl::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 4c9bb2e27b..d95da63405 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -122,7 +122,7 @@ class XlaTensor { std::shared_ptr definition_event_; // A list of all streams for which the tensor's content is defined for any // newly enqueued command. - gtl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); + absl::InlinedVector streams_defined_on_ GUARDED_BY(mu_); mutex mu_; }; diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 34defe1c7a..050d827a09 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1103,6 +1103,7 @@ cc_library( "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 0faf0fd8ed..bddda6f302 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -45,6 +45,8 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/core/common_runtime/device.h" @@ -61,7 +63,6 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -81,7 +82,7 @@ string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; string LocalDeviceToFullDeviceName(const string& device) { - return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); + return absl::StrCat("/job:localhost/replica:0/task:0/device:", device); } constexpr std::array kAllXlaTypes = { @@ -107,11 +108,12 @@ class OpTestBuilder { // Sets an attribute. template - OpTestBuilder& Attr(StringPiece attr_name, T&& value); + OpTestBuilder& Attr(absl::string_view attr_name, T&& value); // Overload needed to allow {...} expressions for value. template - OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list value); + OpTestBuilder& Attr(absl::string_view attr_name, + std::initializer_list value); // Adds nodes that executes the operator under test on 'device' to 'graphdef'. // If 'use_jit' is true, marks the operator under test to be compiled by XLA. @@ -185,13 +187,13 @@ OpTestBuilder& OpTestBuilder::RandomUniqueInput(DataType type, } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) { +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, T&& value) { AddNodeAttr(attr_name, std::forward(value), &node_def_); return *this; } template -OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, +OpTestBuilder& OpTestBuilder::Attr(absl::string_view attr_name, std::initializer_list value) { Attr>(attr_name, std::move(value)); return *this; @@ -209,7 +211,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, NodeDef* test_def = graphdef->add_node(); *test_def = node_def_; - test_def->set_name(strings::StrCat(name_prefix, "_op_under_test")); + test_def->set_name(absl::StrCat(name_prefix, "_op_under_test")); test_def->set_device(device); AddDefaultsToNodeDef(*op_def, test_def); if (use_jit) { @@ -224,7 +226,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, // Build feed and fetch nodes. for (int i = 0; i < input_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_input_", i); + string name = absl::StrCat(name_prefix, "_input_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder") .Device(device) .Attr("dtype", input_types[i]) @@ -235,7 +237,7 @@ Status OpTestBuilder::BuildGraph(const string& name_prefix, for (int i = 0; i < output_types.size(); ++i) { NodeDef* def = graphdef->add_node(); - string name = strings::StrCat(name_prefix, "_output_", i); + string name = absl::StrCat(name_prefix, "_output_", i); TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity") .Device(device) .Attr("T", output_types[i]) @@ -726,11 +728,11 @@ bool IsClose(const complex64& x, const complex64& y, double atol, template string Str(T x) { - return strings::StrCat(x); + return absl::StrCat(x); } template <> string Str(complex64 x) { - return strings::StrCat("(", x.real(), ", ", x.imag(), ")"); + return absl::StrCat("(", x.real(), ", ", x.imag(), ")"); } template @@ -740,11 +742,11 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (!IsClose(Tx(i), Ty(i), atol, rtol)) { - return errors::InvalidArgument(strings::StrCat( - i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ", - Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(), - "atol = ", atol, " rtol = ", rtol, - " tol = ", atol + rtol * Abs(Tx(i)))); + return errors::InvalidArgument( + absl::StrCat(i, "-th tensor element isn't close: ", Str(Tx(i)), + " vs. ", Str(Ty(i)), ". x = ", x.DebugString(), + "y = ", y.DebugString(), "atol = ", atol, + " rtol = ", rtol, " tol = ", atol + rtol * Abs(Tx(i)))); } } return Status::OK(); @@ -756,7 +758,7 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { auto Ty = y.flat(); for (int i = 0; i < Tx.size(); ++i) { if (Tx(i) != Ty(i)) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i), ". x = ", x.DebugString(), "y = ", y.DebugString())); } @@ -771,14 +773,14 @@ Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) { Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, double rtol) { if (a.dtype() != b.dtype()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Tensors have different types: ", DataTypeString(a.dtype()), " and ", DataTypeString(b.dtype()))); } if (!a.IsSameSize(b)) { - return errors::InvalidArgument(strings::StrCat( - "Tensors have different shapes: ", a.shape().DebugString(), " and ", - b.shape().DebugString())); + return errors::InvalidArgument( + absl::StrCat("Tensors have different shapes: ", a.shape().DebugString(), + " and ", b.shape().DebugString())); } switch (a.dtype()) { @@ -827,7 +829,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } string cpu_device = - LocalDeviceToFullDeviceName(strings::StrCat(DEVICE_CPU, ":0")); + LocalDeviceToFullDeviceName(absl::StrCat(DEVICE_CPU, ":0")); string test_device = LocalDeviceToFullDeviceName(*tf_xla_test_device_ptr); DeviceNameUtils::ParsedName parsed_name; @@ -842,7 +844,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( std::vector expected_inputs, test_inputs; std::vector expected_fetches, test_fetches; Status status = builder.BuildGraph( - strings::StrCat("test", num_tests_, "_expected"), cpu_device, + absl::StrCat("test", num_tests_, "_expected"), cpu_device, /* use_jit= */ false, &graph, /* test_node_def= */ nullptr, &expected_inputs, &expected_fetches); if (!status.ok()) { @@ -851,7 +853,7 @@ OpTest::TestResult OpTest::ExpectTfAndXlaOutputsAreClose( } NodeDef* node_def; - status = builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"), + status = builder.BuildGraph(absl::StrCat("test", num_tests_, "_test"), test_device, tf_xla_test_use_jit, &graph, &node_def, &test_inputs, &test_fetches); if (!status.ok()) { diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0797b2cb17..22be7f048f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -291,6 +291,7 @@ cc_library( "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) @@ -433,6 +434,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", ], ) @@ -609,11 +611,10 @@ cc_library( srcs = ["resource_operation_table.cc"], hdrs = ["resource_operation_table.h"], deps = [ - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:ops", - "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index 24616c01c7..380c6a7e23 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/dump_graph_flags.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -52,9 +52,9 @@ string MakeUniqueFilename(string name) { string filename = name; if (count > 0) { - strings::StrAppend(&filename, "_", count); + absl::StrAppend(&filename, "_", count); } - strings::StrAppend(&filename, ".pbtxt"); + absl::StrAppend(&filename, ".pbtxt"); return filename; } @@ -69,7 +69,7 @@ string WriteTextProtoToUniqueFile( << proto_type << ": " << status; return "(unavailable)"; } - string filepath = strings::StrCat(dirname, "/", MakeUniqueFilename(name)); + string filepath = absl::StrCat(dirname, "/", MakeUniqueFilename(name)); status = WriteTextProto(Env::Default(), filepath, proto); if (!status.ok()) { LOG(WARNING) << "Failed to dump " << proto_type << " to file: " << filepath diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index e2affee51f..0911550f1f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -42,7 +42,7 @@ namespace functionalize_cond { // TODO(jpienaar): Move to OutputTensor. string DebugString(const OutputTensor& tensor) { - return strings::StrCat(tensor.node->name(), ":", tensor.index); + return absl::StrCat(tensor.node->name(), ":", tensor.index); } string Branch_Name(BranchType b) { @@ -61,17 +61,17 @@ string Branch_Name(BranchType b) { string DebugString(StateMap::CondId cond_state) { if (cond_state == nullptr || cond_state->empty()) return "{}"; using value_type = StateMap::CondState::value_type; - return strings::StrCat( + return absl::StrCat( "{", absl::StrJoin(*cond_state, ", ", [](string* output, const value_type& pred_branch) { const OutputTensor& pred = pred_branch.first; const BranchType& branch = pred_branch.second; if (branch == BranchType::kNeither) - strings::StrAppend(output, "d"); + absl::StrAppend(output, "d"); else - strings::StrAppend(output, "s(", DebugString(pred), ",", - Branch_Name(branch), ")"); + absl::StrAppend(output, "s(", DebugString(pred), ",", + Branch_Name(branch), ")"); }), "}"); } @@ -159,8 +159,8 @@ struct CondArgNode { : src(src), src_output(src_output) {} string ToString() const { - return strings::StrCat("src=", src->name(), ":", src_output, - " switches=", NodesToString(switches)); + return absl::StrCat("src=", src->name(), ":", src_output, + " switches=", NodesToString(switches)); } Node* src; @@ -171,11 +171,11 @@ struct CondArgNode { using CondArgNodes = std::vector; string DebugString(const CondArgNodes& nodes) { - return strings::StrCat( + return absl::StrCat( "[", absl::StrJoin(nodes, ", ", [](string* output, const CondArgNode& node) { - strings::StrAppend(output, node.ToString()); + absl::StrAppend(output, node.ToString()); }), "]"); } @@ -373,7 +373,7 @@ Status Conditional::BuildArgumentNodes() { for (auto branch : {BranchType::kElseBranch, BranchType::kThenBranch}) { int branch_index = static_cast(branch); TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_Arg", arg_count), + NodeBuilder(absl::StrCat("_Arg", arg_count), FunctionLibraryDefinition::kArgOp) .Attr("T", dtype) .Attr("index", arg_count) @@ -441,7 +441,7 @@ Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, Node* src = edge->src(); int src_output = edge->src_output(); TF_RETURN_IF_ERROR( - NodeBuilder(graph->NewName(strings::StrCat(src->name(), "_added_switch")), + NodeBuilder(graph->NewName(absl::StrCat(src->name(), "_added_switch")), "Switch") .Input(src, src_output) .Input(const_cast(predicate_.node), predicate_.index) @@ -650,8 +650,8 @@ Status Conditional::BuildIfNode(Graph* graph, int64 id = ++sequence_num; NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_if_", - branch_name[branch_index], "_", id)); + body_name.set_name( + absl::StrCat("_functionalize_if_", branch_name[branch_index], "_", id)); VLOG(3) << "FunctionalizeControlFlow (" << branch_name[branch_index] << "): " @@ -804,7 +804,7 @@ Status Conditional::BuildAndReplace(Graph* graph, string Conditional::name() const { CHECK(!merges_.empty()); - return strings::StrCat((*merges_.begin())->name(), "_if"); + return absl::StrCat((*merges_.begin())->name(), "_if"); } Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, @@ -1327,12 +1327,12 @@ void FunctionalizeCond::DumpGraphWithCondState(const string& name) { for (Node* n : graph_->nodes()) { n->ClearAttr(kCondGroupDebugAttr); n->AddAttr(kCondGroupDebugAttr, - strings::StrCat(state_map_.CondStateToString(n), "_", - state_map_.AncestorStateToString(n))); + absl::StrCat(state_map_.CondStateToString(n), "_", + state_map_.AncestorStateToString(n))); } LOG(INFO) << "FunctionalizeControlFlow (" << name << "): " - << dump_graph::DumpGraphToFile( - strings::StrCat("functionalize_", name), *graph_, library_); + << dump_graph::DumpGraphToFile(absl::StrCat("functionalize_", name), + *graph_, library_); } Status FunctionalizeCond::Functionalize(Graph* graph, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index 924fcdd9cd..54cebc6177 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -42,7 +42,7 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { const char* const kRetValOp = "_Retval"; NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(kRetValOp, index)); + ret_def.set_name(absl::StrCat(kRetValOp, index)); AddNodeAttr("T", type, &ret_def); AddNodeAttr("index", index, &ret_def); return AddNodeDefToGraph(ret_def, graph); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index 61940e3586..582b49d511 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -43,13 +43,12 @@ xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index); // Returns a textual representation of the names of the nodes in the input. template string NodesToString(const T& nodes) { - return strings::StrCat("{", - absl::StrJoin(nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, - node->name()); - }), - "}"); + return absl::StrCat("{", + absl::StrJoin(nodes, ",", + [](string* output, const Node* node) { + absl::StrAppend(output, node->name()); + }), + "}"); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 6e3c4b0e0f..7f45e3bffa 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -132,7 +132,7 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, StatusOr BuildArgNode(Graph* graph, DataType type, int index) { const char* const kArgOp = "_Arg"; NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); + NodeDefBuilder builder(absl::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); builder.Attr("index", index); TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); @@ -487,9 +487,9 @@ Status FunctionalizeLoop(const FunctionLibraryDefinition* lookup_library, static std::atomic sequence_num(0LL); int64 id = ++sequence_num; NameAttrList cond_name; - cond_name.set_name(strings::StrCat("_functionalize_cond_", id)); + cond_name.set_name(absl::StrCat("_functionalize_cond_", id)); NameAttrList body_name; - body_name.set_name(strings::StrCat("_functionalize_body_", id)); + body_name.set_name(absl::StrCat("_functionalize_body_", id)); FunctionDef cond_fdef; TF_RETURN_IF_ERROR( GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef)); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 1ed1fb3b02..bc2e640559 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -127,7 +127,7 @@ Status GraphCompiler::Compile() { TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch()) << "Not supported node: " << n->DebugString(); params.op_kernel = op_kernel.get(); - gtl::InlinedVector output_attr(n->num_outputs()); + absl::InlinedVector output_attr(n->num_outputs()); params.output_attr_array = output_attr.data(); // tensor_inputs_ is a buffer reused across graph traversal. We clean up and diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index 127562eb23..ab7cac7100 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -89,7 +89,7 @@ class GraphCompiler { ScopedStepContainer* step_container_; // A buffer to hold tensor inputs to a node, this is reused across the graph // traversal. - gtl::InlinedVector tensor_inputs_; + absl::InlinedVector tensor_inputs_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index edced6bc0e..a18e04995b 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -26,7 +26,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp& input, absl::Span block_shape, const xla::Literal& crops) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 2e383b1473..182f7c9934 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -39,7 +39,7 @@ class BCastArgsOp : public XlaOpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), @@ -88,7 +88,7 @@ class BCastGradArgsOp : public XlaOpKernel { ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const TensorShape in_shape = ctx->InputShape(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 12b0e38288..e96a1adce4 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -48,7 +48,7 @@ class DepthToSpaceOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got: ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index f6f158a73b..27690c156e 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -138,7 +138,7 @@ xla::TensorFormat XlaTensorFormat(tensorflow::TensorFormat data_format, int num_dims = num_spatial_dims + 2; int batch_dimension = GetTensorBatchDimIndex(num_dims, data_format); int feature_dimension = GetTensorFeatureDimIndex(num_dims, data_format); - gtl::InlinedVector spatial_dimensions(num_spatial_dims); + absl::InlinedVector spatial_dimensions(num_spatial_dims); for (int spatial_dim = 0; spatial_dim < num_spatial_dims; ++spatial_dim) { spatial_dimensions[spatial_dim] = GetTensorSpatialDimIndex(num_dims, data_format, spatial_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 598248563b..118f2798d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -69,7 +69,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "data shape: " << data_shape.DebugString(); VLOG(1) << "axes : " << absl::StrJoin(axes, ","); - gtl::InlinedVector bitmap(data_shape.dims(), false); + absl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { @@ -103,7 +103,7 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { xla::XlaBuilder* const b = ctx->builder(); // Construct the builder for the reduction lambda. - xla::XlaBuilder r(strings::StrCat(desc, "-reduction")); + xla::XlaBuilder r(absl::StrCat(desc, "-reduction")); xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(reduction_type_, &type)); diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index c0afccaa5b..8494864b33 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -97,7 +97,7 @@ class ReverseV2Op : public XlaOpKernel { // witnessed_axes is used to ensure that the same axis is not marked to be // reversed multiple times. - gtl::InlinedVector witnessed_axes(x_shape.dims(), false); + absl::InlinedVector witnessed_axes(x_shape.dims(), false); for (int d = 0; d < axes.size(); ++d) { OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 4e0cf99d8e..2e0a69b70e 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -115,7 +115,7 @@ class ExpandDimsOp : public XlaOpKernel { // accept legacy scalars, even when they should be forbidden by the graphdef // version. OP_REQUIRES(ctx, dim_shape.num_elements() == 1, - errors::InvalidArgument(strings::StrCat( + errors::InvalidArgument(absl::StrCat( "dim input to ExpandDims must be a scalar; got ", dim_shape.DebugString()))); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index b7b4f3a546..76b79be6f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -26,7 +26,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input, absl::Span block_shape, const xla::Literal& paddings) { const int input_rank = input_tensor_shape.dims(); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); const int block_rank = block_shape.size(); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 4493539fe3..3293c13b21 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -48,7 +48,7 @@ class SpaceToDepthOp : public XlaOpKernel { OP_REQUIRES(ctx, kRequiredDims == input_rank, errors::InvalidArgument("Input rank should be ", kRequiredDims, "; got ", input_rank)); - const gtl::InlinedVector input_shape = + const absl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::XlaOp input = ctx->Input(0); diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index df91900570..ee70f508a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -111,7 +111,7 @@ class StackOp : public XlaOpKernel { xla::XlaOp value; XlaContext& xc = XlaContext::Get(ctx); XlaResource* resource; - string name = strings::StrCat("Stack: ", stack_name_); + string name = absl::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, TensorShape(), value, /*tensor_array_size=*/size, diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 472d4744d7..2b2e3de64f 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -46,9 +46,9 @@ class StridedSliceOp : public XlaOpKernel { const TensorShape input_shape = ctx->InputShape(0); TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -72,8 +72,8 @@ class StridedSliceOp : public XlaOpKernel { shrink_axis_mask_, &dummy_processing_shape, &final_shape, &dummy, &dummy, &dummy, &begin, &end, &strides)); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_end, slice_strides; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_end, slice_strides; for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { @@ -127,9 +127,9 @@ class StridedSliceGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape processing_shape, final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -175,7 +175,7 @@ class StridedSliceGradOp : public XlaOpKernel { grad = xla::Reshape(grad, processing_shape.dim_sizes()); // Pad the input gradients. - gtl::InlinedVector dimensions_to_reverse; + absl::InlinedVector dimensions_to_reverse; xla::PaddingConfig padding_config; for (int i = 0; i < processing_shape.dims(); ++i) { @@ -238,9 +238,9 @@ class StridedSliceAssignOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape final_shape; - gtl::InlinedVector begin; - gtl::InlinedVector end; - gtl::InlinedVector strides; + absl::InlinedVector begin; + absl::InlinedVector end; + absl::InlinedVector strides; xla::Literal begin_literal, end_literal, strides_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal)); @@ -287,8 +287,8 @@ class StridedSliceAssignOp : public XlaOpKernel { xla::XlaOp rhs = ctx->Input(4); - gtl::InlinedVector dimensions_to_reverse; - gtl::InlinedVector slice_begin, slice_dims; + absl::InlinedVector dimensions_to_reverse; + absl::InlinedVector slice_begin, slice_dims; for (int i = 0; i < begin.size(); ++i) { // TODO(phawkins): implement strides != 1 OP_REQUIRES( diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index bb114d1aed..94108b764f 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -167,7 +167,7 @@ class TensorArrayOp : public XlaOpKernel { XlaContext& xc = XlaContext::Get(ctx); XlaResource* var; - string name = strings::StrCat("TensorArray: ", tensor_array_name_); + string name = absl::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), dtype_, shape, value, /*tensor_array_size=*/size, diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index f9148b3942..6b303b31d4 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -61,7 +61,7 @@ class TransposeOp : public XlaOpKernel { std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). - gtl::InlinedVector bits(dims); + absl::InlinedVector bits(dims); bool is_identity = true; for (int i = 0; i < dims; ++i) { const int32 d = perm[i]; diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 9365d203f0..8597e7f139 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -205,7 +205,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/core:lib", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc index 5300e2c878..594ab1dfd0 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.cc +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -24,7 +24,7 @@ namespace tensorflow { xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { int arity = initial_values.size(); std::vector var_shapes; @@ -47,7 +47,7 @@ xla::StatusOr> XlaWhileLoop( // Build the condition. std::unique_ptr cond_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_condition")); + builder->CreateSubBuilder(absl::StrCat(name, "_condition")); { auto parameter = xla::Parameter(cond_builder.get(), 0, tuple_shape, "parameter"); @@ -61,7 +61,7 @@ xla::StatusOr> XlaWhileLoop( // Build the body. std::unique_ptr body_builder = - builder->CreateSubBuilder(strings::StrCat(name, "_body")); + builder->CreateSubBuilder(absl::StrCat(name, "_body")); { auto parameter = xla::Parameter(body_builder.get(), 0, tuple_shape, "parameter"); @@ -84,7 +84,7 @@ xla::StatusOr> XlaWhileLoop( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder) { auto while_cond_fn = [&](absl::Span values, diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h index 115ebf390d..f2134bb449 100644 --- a/tensorflow/compiler/tf2xla/lib/while_loop.h +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -19,11 +19,11 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/core/lib/core/stringpiece.h" namespace tensorflow { @@ -50,7 +50,7 @@ typedef std::function>( xla::StatusOr> XlaWhileLoop( const LoopConditionFunction& condition_function, const LoopBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); // Builds an XLA loop that repeats a computation `num_iterations` times. @@ -65,7 +65,7 @@ typedef std::function>( xla::StatusOr> XlaForEachIndex( int64 num_iterations, xla::PrimitiveType num_iterations_type, const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, StringPiece name, + absl::Span initial_values, absl::string_view name, xla::XlaBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc index 32ba6df2e6..20f2ce2919 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" namespace tensorflow { -/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString( +/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString( XlaResourceOpKind op_kind) { switch (op_kind) { case XlaResourceOpKind::kRead: @@ -30,11 +30,11 @@ namespace tensorflow { } } -static gtl::FlatMap* CreateResourceOpInfoMap() { - gtl::FlatMap* result = - new gtl::FlatMap; +static gtl::FlatMap* +CreateResourceOpInfoMap() { + auto* result = new gtl::FlatMap; - auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + auto add = [&](absl::string_view op, XlaResourceOpKind op_kind, XlaResourceKind resource_kind) { auto insert_result = result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); @@ -103,23 +103,23 @@ static gtl::FlatMap* CreateResourceOpInfoMap() { return result; } -static const gtl::FlatMap& +static const gtl::FlatMap& GetStaticResourceOpInfoMap() { - static gtl::FlatMap* op_info_map = + static gtl::FlatMap* op_info_map = CreateResourceOpInfoMap(); return *op_info_map; } -const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) { - const gtl::FlatMap& op_infos = +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) { + const gtl::FlatMap& op_infos = GetStaticResourceOpInfoMap(); auto it = op_infos.find(op); return it == op_infos.end() ? nullptr : &it->second; } namespace resource_op_table_internal { -std::vector GetKnownResourceOps() { - std::vector result; +std::vector GetKnownResourceOps() { + std::vector result; for (const auto& p : GetStaticResourceOpInfoMap()) { result.push_back(p.first); } diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h index 7f627a64c6..61c7a56ff0 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table.h +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/stringpiece.h" +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" // Exposes information about the resource operations supported by tf2xla in a @@ -47,7 +47,7 @@ class XlaResourceOpInfo { XlaResourceOpKind kind() const { return op_kind_; } XlaResourceKind resource_kind() const { return resource_kind_; } - static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind); + static absl::string_view XlaResourceOpKindToString(XlaResourceOpKind op_kind); private: XlaResourceOpKind op_kind_; @@ -57,13 +57,13 @@ class XlaResourceOpInfo { // Returns a XlaResourceOpInfo describing `op` if it is a resource operation // supported by tf2xla, otherwise returns null (i.e. if this returns null then // `op` is either not a resource operation or is unsupported by XLA). -const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op); +const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op); namespace resource_op_table_internal { // NB! Implementation detail exposed for unit testing, do not use. // // Returns the set of resource operations known by this module. -std::vector GetKnownResourceOps(); +std::vector GetKnownResourceOps(); } // namespace resource_op_table_internal } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc index 0343f80de9..a85ef040a7 100644 --- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -34,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) { TEST(ResourceOperationTableTest, HaveAllResourceOps) { gtl::FlatMap known_resource_ops; - for (StringPiece known_resource_op : + for (absl::string_view known_resource_op : resource_op_table_internal::GetKnownResourceOps()) { ASSERT_TRUE( known_resource_ops.insert({string(known_resource_op), false}).second); diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 2d7eb8b915..8aae498be1 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -17,7 +17,6 @@ limitations under the License. #include "absl/strings/match.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index f34af2d67d..7dbe3a0b58 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/shape_util.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -75,7 +75,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, auto node_it = node_map.find(remap_it->second); if (node_it == node_map.end()) { // Strip off the aot_feed_#/ prefix. - StringPiece name(remap_it->second); + absl::string_view name(remap_it->second); const auto index = name.find('/'); if (index > 0) name.remove_prefix(index + 1); return errors::InvalidArgument( @@ -89,7 +89,7 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, // explicitly specify or override them. Node* arg_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp) + NodeBuilder(absl::StrCat("_arg_", arg_index), kArgOp) .Attr("T", BaseType(feed_node->output_type(output_index))) .Attr("index", arg_index) .Attr(kFeedIdAttr, TensorIdToString(feed.id())) @@ -136,7 +136,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, // Connects fetch_node -> retval_node. Node* retval_node = nullptr; TF_RETURN_IF_ERROR( - NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp) + NodeBuilder(absl::StrCat("_retval_", ret_index), kRetvalOp) .Input(fetch_node, id.output_index()) .Attr("T", BaseType(fetch_node->output_type(id.output_index()))) .Attr("index", ret_index) @@ -256,7 +256,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( - strings::StrCat("/device:", DEVICE_CPU_XLA_JIT)); + absl::StrCat("/device:", DEVICE_CPU_XLA_JIT)); } std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index e284e0b191..211caf8736 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -112,8 +112,8 @@ Status AddPlaceholdersForFeeds( const string name_port = TensorIdToString(feed->id()); PlaceholderInfo& info = placeholder_info[name_port]; info.feed = feed; - info.placeholder_name = strings::StrCat( - "aot_feed_", feed->id().output_index(), "/", feed->id().node_name()); + info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(), + "/", feed->id().node_name()); (*feed_remapping)[name_port] = info.placeholder_name; } @@ -258,7 +258,7 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, } string TensorIdToString(const tf2xla::TensorId& id) { - return strings::StrCat(id.node_name(), ":", id.output_index()); + return absl::StrCat(id.node_name(), ":", id.output_index()); } Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { @@ -289,7 +289,7 @@ Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { return Status::OK(); } -void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef) { for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) { if (constraint.name() == name) { diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 33620ef810..a29e764466 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -53,7 +53,7 @@ string TensorIdToString(const tf2xla::TensorId& id); Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); // Add an allowed data type to the AttrConstraint with the given name. -void AddDtypeToKernalDefConstraint(StringPiece name, DataType dtype, +void AddDtypeToKernalDefConstraint(absl::string_view name, DataType dtype, KernelDef* kdef); // Returns the next random seed to use for seeding xla rng. diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 2b1f724dc7..68441b3d47 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" @@ -25,8 +27,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -153,7 +153,7 @@ static tf2xla::Config FetchesConfig(std::vector fetches) { tf2xla::Config config; for (const auto& fetch_node_name : fetches) { auto* fetch = config.add_fetch(); - fetch->set_name(strings::StrCat("fetch_", fetch_node_name)); + fetch->set_name(absl::StrCat("fetch_", fetch_node_name)); fetch->mutable_id()->set_node_name(fetch_node_name); } return config; diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index d98237bd5c..7f860500c7 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -76,12 +76,11 @@ class XlaCompilationAllocator : public Allocator { XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, DeviceType type) - : LocalDevice( - options, - Device::BuildDeviceAttributes( - strings::StrCat("/device:", type.type(), ":0"), type, - Bytes(256 << 20), DeviceLocality(), - strings::StrCat("device: XLA compilation device ", type.type()))), + : LocalDevice(options, Device::BuildDeviceAttributes( + absl::StrCat("/device:", type.type(), ":0"), + type, Bytes(256 << 20), DeviceLocality(), + absl::StrCat("device: XLA compilation device ", + type.type()))), allocator_(new XlaCompilationAllocator()) {} XlaCompilationDevice::~XlaCompilationDevice() {} diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0c300c282e..41d305d461 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -198,14 +198,14 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == "_Arg") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). for (Node* n : graph->nodes()) { - if (StringPiece(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == "_Retval") { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -213,8 +213,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_function_", function_id), - *graph); + absl::StrCat("xla_compile_function_", function_id), *graph); } VLOG(1) << "===================================================="; @@ -522,7 +521,7 @@ Status XlaCompiler::BuildArguments( // Use the _Arg nodes in the graph to resolve core assignments. for (const Node* n : graph.nodes()) { - if (StringPiece(n->type_string()) != "_Arg") continue; + if (absl::string_view(n->type_string()) != "_Arg") continue; int index; TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); TF_RET_CHECK(index >= 0 && index < args.size()) @@ -581,7 +580,7 @@ Status XlaCompiler::BuildArguments( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = xla::Parameter(builder, i, (*input_shapes)[i], - strings::StrCat("arg", i)); + absl::StrCat("arg", i)); } } @@ -644,7 +643,7 @@ Status XlaCompiler::CompileSingleOp( // dependency edge to the _SOURCE node. for (int64 i = 0; i < ctx->num_inputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_arg"); Status status = NodeBuilder(name, "_Arg") .ControlInput(graph->source_node()) .Attr("T", ctx->input_dtype(i)) @@ -657,7 +656,7 @@ Status XlaCompiler::CompileSingleOp( // Similarly with return values, create dummy _Retval nodes fed by `node`. for (int64 i = 0; i < ctx->num_outputs(); ++i) { Node* node; - string name = strings::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); + string name = absl::StrCat(ctx->op_kernel().name(), "_", i, "_retval"); Status status = NodeBuilder(name, "_Retval") .Input(main_node, i) .Attr("T", ctx->expected_output_dtype(i)) @@ -693,7 +692,7 @@ Status ValidateGraph(const Graph* graph, const DeviceType& device_type, const string& name) { auto maybe_error = [&](const Node* node, const Status& s) -> Status { if (!s.ok()) { - return errors::InvalidArgument(strings::StrCat( + return errors::InvalidArgument(absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, " on ", device_type.type_string(), ": ", node->def().op(), " (", s.error_message(), ")", FormatNodeForError(*node))); @@ -734,7 +733,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileGraph: " << dump_graph::DumpGraphToFile( - strings::StrCat("xla_compile_graph_", name), *graph); + absl::StrCat("xla_compile_graph_", name), *graph); } // Report the error here if initialization failed. diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 24a4b92b45..e8b4b0eb36 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -32,7 +32,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 1499c99ed1..d67e50375b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -67,7 +67,7 @@ const xla::XlaOp& XlaOpKernelContext::Input(int index) { return GetComputationFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(StringPiece name) { +const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { return GetComputationFromTensor(GetInputTensorByName(name)); } @@ -75,7 +75,7 @@ TensorShape XlaOpKernelContext::InputShape(int index) { return context_->input(index).shape(); } -TensorShape XlaOpKernelContext::InputShape(StringPiece name) { +TensorShape XlaOpKernelContext::InputShape(absl::string_view name) { return GetInputTensorByName(name).shape(); } @@ -100,7 +100,7 @@ Status XlaOpKernelContext::ConstantInput(int index, } static xla::StatusOr InputIndex(XlaOpKernelContext* context, - StringPiece name) { + absl::string_view name) { int start, stop; TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &start, &stop)); if (stop != start + 1) { @@ -112,7 +112,7 @@ static xla::StatusOr InputIndex(XlaOpKernelContext* context, return start; } -Status XlaOpKernelContext::ConstantInput(StringPiece name, +Status XlaOpKernelContext::ConstantInput(absl::string_view name, xla::Literal* constant_literal) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInput(index, constant_literal); @@ -265,7 +265,7 @@ Status XlaOpKernelContext::ConstantInputAsIntScalar(int index, int64* out) { return LiteralToInt64Scalar(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntScalar(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsIntScalar(absl::string_view name, int64* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntScalar(index, out); @@ -305,7 +305,7 @@ Status XlaOpKernelContext::ConstantInputAsIntVector(int index, return LiteralToInt64Vector(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntVector(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsIntVector(absl::string_view name, std::vector* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntVector(index, out); @@ -344,7 +344,7 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, } } -Status XlaOpKernelContext::ConstantInputAsInt64Literal(StringPiece name, +Status XlaOpKernelContext::ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsInt64Literal(index, out); @@ -361,7 +361,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) { return Status::OK(); } -Status XlaOpKernelContext::InputList(StringPiece name, +Status XlaOpKernelContext::InputList(absl::string_view name, std::vector* handles, std::vector* shapes) { OpInputList inputs; @@ -376,7 +376,7 @@ Status XlaOpKernelContext::InputList(StringPiece name, } Status XlaOpKernelContext::ConstantInputList( - StringPiece name, std::vector* outputs) { + absl::string_view name, std::vector* outputs) { int start, stop; TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); @@ -429,8 +429,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, value); } -Status XlaOpKernelContext::ReadVariableInput(StringPiece name, DataType type, - TensorShape* shape, +Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, + DataType type, TensorShape* shape, xla::XlaOp* value) { return ReadVariableInputTensor(GetInputTensorByName(name), type, context_, shape, value); @@ -564,7 +564,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, handle, builder()); } -Status XlaOpKernelContext::AssignVariable(StringPiece name, DataType type, +Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); return AssignVariableTensor(GetInputTensorByName(name), type, context_, @@ -610,7 +610,7 @@ const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul( return XlaContext::Get(context_).GetOrCreateMul(type); } -const Tensor& XlaOpKernelContext::GetInputTensorByName(StringPiece name) { +const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) { const Tensor* tensor; CHECK(context_->input(name, &tensor).ok()); return *tensor; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 45cfa7da74..962c86d3a5 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -80,14 +80,14 @@ class XlaOpKernelContext { TensorShape InputShape(int index); // Returns the shape of input `name`. - TensorShape InputShape(StringPiece name); + TensorShape InputShape(absl::string_view name); // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. const xla::XlaOp& Input(int index); // Returns input `name` as a XlaOp. - const xla::XlaOp& Input(StringPiece name); + const xla::XlaOp& Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. @@ -97,7 +97,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(StringPiece name, std::vector* handles, + Status InputList(absl::string_view name, std::vector* handles, std::vector* shapes); // Helper methods for constant inputs. @@ -106,7 +106,7 @@ class XlaOpKernelContext { // expression cannot be evaluated, e.g., because it depends on unbound // parameters, returns a non-OK status. Status ConstantInput(int index, xla::Literal* constant_literal); - Status ConstantInput(StringPiece name, xla::Literal* constant_literal); + Status ConstantInput(absl::string_view name, xla::Literal* constant_literal); // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input @@ -118,14 +118,15 @@ class XlaOpKernelContext { // Converts a constant scalar int32 or int64 tensor into an int64. Status ConstantInputAsIntScalar(int index, int64* out); - Status ConstantInputAsIntScalar(StringPiece name, int64* out); + Status ConstantInputAsIntScalar(absl::string_view name, int64* out); // Converts a constant scalar float32 or float64 tensor into a float64. Status ConstantInputAsFloatScalar(int index, double* out); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. Status ConstantInputAsIntVector(int index, std::vector* out); - Status ConstantInputAsIntVector(StringPiece name, std::vector* out); + Status ConstantInputAsIntVector(absl::string_view name, + std::vector* out); // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. @@ -133,7 +134,7 @@ class XlaOpKernelContext { // Converts a constant int32 or int64 Tensor into an xla int64 Literal. Status ConstantInputAsInt64Literal(int index, xla::Literal* out); - Status ConstantInputAsInt64Literal(StringPiece name, xla::Literal* out); + Status ConstantInputAsInt64Literal(absl::string_view name, xla::Literal* out); // Converts a constant 1D int32 or int64 tensor into a TensorShape. Status ConstantInputAsShape(int index, TensorShape* shape); @@ -141,7 +142,7 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status ConstantInputList(StringPiece name, + Status ConstantInputList(absl::string_view name, std::vector* literals); // Outputs @@ -190,8 +191,8 @@ class XlaOpKernelContext { xla::XlaOp* value); // Reads the current value of the resouce variable referred to by input // `name`. - Status ReadVariableInput(StringPiece name, DataType type, TensorShape* shape, - xla::XlaOp* value); + Status ReadVariableInput(absl::string_view name, DataType type, + TensorShape* shape, xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the @@ -199,7 +200,8 @@ class XlaOpKernelContext { // different shape. Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); // Assigns the value `handle` to the variable referenced by input `name`. - Status AssignVariable(StringPiece name, DataType type, xla::XlaOp handle); + Status AssignVariable(absl::string_view name, DataType type, + xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); @@ -248,7 +250,7 @@ class XlaOpKernelContext { private: // Returns the tensor of input `name`. - const Tensor& GetInputTensorByName(StringPiece name); + const Tensor& GetInputTensorByName(absl::string_view name); OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index dae2d956ca..b0eeee3174 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -371,26 +371,28 @@ XlaOpRegistry& XlaOpRegistry::Instance() { return *r; } -XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(StringPiece name) { +XlaOpRegistrationBuilder::XlaOpRegistrationBuilder(absl::string_view name) { registration_.reset(new XlaOpRegistry::OpRegistration); registration_->name = string(name); } -XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(StringPiece name) { +XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name( + absl::string_view name) { XlaOpRegistrationBuilder registration(name); return registration; } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( - absl::Span devices) { + absl::Span devices) { registration_->has_device_whitelist = true; - for (StringPiece device : devices) { + for (absl::string_view device : devices) { registration_->device_whitelist.emplace(device); } return *this; } -XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(StringPiece device) { +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device( + absl::string_view device) { registration_->has_device_whitelist = true; registration_->device_whitelist.emplace(device); return *this; @@ -407,7 +409,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowResourceTypes() { } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, DataType allowed) { + absl::string_view attr_name, DataType allowed) { std::set& types = registration_->type_constraints[string(attr_name)]; types.insert(allowed); @@ -415,7 +417,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( - StringPiece attr_name, absl::Span allowed) { + absl::string_view attr_name, absl::Span allowed) { std::set& types = registration_->type_constraints[string(attr_name)]; for (DataType t : allowed) { @@ -425,7 +427,7 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( } XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( - StringPiece input_name) { + absl::string_view input_name) { registration_->compile_time_constant_inputs.emplace(input_name); return *this; } @@ -452,7 +454,7 @@ XlaOpRegistrar::XlaOpRegistrar( } XlaBackendRegistrar::XlaBackendRegistrar( - StringPiece name, absl::Span types, + absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter) { XlaOpRegistry& registry = XlaOpRegistry::Instance(); registry.RegisterBackend(string(name), types, op_filter); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index c640842dc0..74a4885f1f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -232,18 +232,18 @@ class XlaOpRegistry { class XlaOpRegistrationBuilder { public: // Starts an operator registration chain. - static XlaOpRegistrationBuilder Name(StringPiece name); + static XlaOpRegistrationBuilder Name(absl::string_view name); // Specifies a whitelist of devices on which the operator may run. - XlaOpRegistrationBuilder& Device(StringPiece devices); - XlaOpRegistrationBuilder& Device(absl::Span devices); + XlaOpRegistrationBuilder& Device(absl::string_view devices); + XlaOpRegistrationBuilder& Device(absl::Span devices); // Specifies a type constraint for a type variable attribute. Each constraint // specifies the set of types that the type variable may assume. - XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, DataType allowed); - XlaOpRegistrationBuilder& TypeConstraint(StringPiece attr_name, + XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name, absl::Span allowed); // Specifies that a dummy copy of this operator should not be registered on @@ -254,13 +254,13 @@ class XlaOpRegistrationBuilder { XlaOpRegistrationBuilder& AllowResourceTypes(); // Mark 'input_name' as an argument whose value must be known at compile-time. - XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); std::unique_ptr Build( XlaOpRegistry::Factory factory); private: - XlaOpRegistrationBuilder(StringPiece name); + XlaOpRegistrationBuilder(absl::string_view name); std::unique_ptr registration_; }; @@ -288,7 +288,7 @@ class XlaOpRegistrar { class XlaBackendRegistrar { public: - XlaBackendRegistrar(StringPiece name, absl::Span types, + XlaBackendRegistrar(absl::string_view name, absl::Span types, XlaOpRegistry::BackendOpFilter op_filter = nullptr); }; diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 7928fa0347..56c2e01055 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -43,7 +43,7 @@ XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, for (const string& gradient : tensor_array_gradients) { tensor_array_gradients_[gradient].reset(new XlaResource( /*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), type_, shape_, + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, xla::XlaOp(), tensor_array_size_, /*tensor_array_gradients=*/{})); } } @@ -135,7 +135,7 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, xla::Broadcast(XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, - /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + /*name=*/absl::StrCat("TensorArrayGrad: ", name_), type_, shape_, gradient_value, tensor_array_size_, /*tensor_array_gradients=*/{})); } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f6cfac6537..64141ed191 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2520,6 +2520,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "@com_google_absl//absl/container:inlined_vector", ], ) @@ -3187,6 +3188,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc index c822c94f1b..8a6e5327e0 100644 --- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -259,7 +259,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionTwoLoops) { TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { // Fusing a reduce into a loop fusion would require changing the fusion kind. // That's not supported yet. - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -277,7 +277,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopReduceToInputFusion) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[6400]{0} parameter(0) ROOT mul = f32[6400]{0} multiply(p0.1, p0.1) @@ -301,7 +301,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionLoopElementwise) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) ROOT mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) @@ -324,7 +324,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopsDifferentShapes) { } TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) @@ -358,7 +358,7 @@ TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoop) { TEST_F(MultiOutputFusionTest, MultiOutputFusionSiblingLoopAndMultiOutputLoopDifferentShapes) { - auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + auto module = ParseHloString(absl::StrCat(kModulePrefix, R"( fused_computation_1 { p0.1 = f32[8,1,5,16,1,1]{5,4,3,2,1,0} parameter(0) mul = f32[8,1,5,16,1,1]{5,4,3,2,1,0} multiply(p0.1, p0.1) diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index cb367adf5e..b59c9ba3ed 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/hash/hash.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc index aab1180662..56145822be 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "absl/algorithm/container.h" +#include "absl/container/inlined_vector.h" #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index efbe980278..2ff97914f8 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/stream_executor", "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", ], ) diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index 68ba17a424..9e3d2454d1 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -46,19 +46,15 @@ cc_library( deps = [ ":xrt_state_ops", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xrt:xrt_proto", "//tensorflow/compiler/xrt:xrt_utils", "//tensorflow/core:core_cpu_internal", @@ -67,6 +63,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor:stream_executor_headers_lib", + "@com_google_absl//absl/strings", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index 5cf2bc8861..1d4f8d97f2 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/proto_serialization.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/types.h" @@ -70,7 +70,7 @@ Status CompilationCacheKey(const xrt::XLAComputation& computation, string serialized; TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized)); uint64 fingerprint = Fingerprint64(serialized); - *key = strings::StrCat(fingerprint); + *key = absl::StrCat(fingerprint); return Status::OK(); } diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc index 911ac9a78b..2c3b07da58 100644 --- a/tensorflow/compiler/xrt/xrt_state.cc +++ b/tensorflow/compiler/xrt/xrt_state.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/stream_executor/stream_executor.h" @@ -201,14 +201,14 @@ const se::DeviceMemoryBase& XRTTupleAllocation::root_allocation() { /*static*/ Status XRTTupleAllocation::Lookup(ResourceMgr* rm, int64 key, XRTTupleAllocation** allocation) { - string key_string = strings::StrCat(key); + string key_string = absl::StrCat(key); TF_RETURN_IF_ERROR(rm->Lookup(kTupleContainer, key_string, allocation)); return Status::OK(); } /*static*/ Status XRTTupleAllocation::DeleteFromResourceManager(ResourceMgr* rm, int64 key) { - string key_string = strings::StrCat(key); + string key_string = absl::StrCat(key); return rm->Delete(kTupleContainer, key_string); } @@ -410,7 +410,7 @@ typedef XRTBufferAllocation* XRTBufferAllocationPtr; Status XRTTupleAllocation::Intern(ResourceMgr* rm, int64* key) { *key = get_uid(); - string key_string = strings::StrCat(*key); + string key_string = absl::StrCat(*key); return rm->Create(kTupleContainer, key_string, this); } -- GitLab From 9059375e16a563af1cc208a8f4cb898a4892a396 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Wed, 5 Sep 2018 14:01:26 -0700 Subject: [PATCH 138/540] [XLA] Rename PrecisionConfigProto to PrecisionConfig The "Proto" suffix adds little clarity but makes a long type name even longer. PiperOrigin-RevId: 211693871 --- tensorflow/compiler/tests/xla_ops_test.py | 10 +- .../compiler/tf2xla/kernels/xla_conv_op.cc | 2 +- .../compiler/tf2xla/kernels/xla_dot_op.cc | 2 +- tensorflow/compiler/tf2xla/lib/batch_dot.cc | 4 +- tensorflow/compiler/tf2xla/lib/batch_dot.h | 10 +- tensorflow/compiler/tf2xla/lib/cholesky.cc | 4 +- tensorflow/compiler/tf2xla/lib/cholesky.h | 6 +- tensorflow/compiler/tf2xla/lib/qr.cc | 6 +- tensorflow/compiler/tf2xla/lib/qr.h | 3 +- .../compiler/tf2xla/lib/triangular_solve.cc | 12 +-- .../compiler/tf2xla/lib/triangular_solve.h | 9 +- tensorflow/compiler/tf2xla/ops/xla_ops.cc | 4 +- tensorflow/compiler/xla/client/xla_builder.cc | 82 +++++++--------- tensorflow/compiler/xla/client/xla_builder.h | 97 +++++++++---------- tensorflow/compiler/xla/reference_util.cc | 4 +- .../xla/service/algebraic_simplifier_test.cc | 11 +-- .../service/bfloat16_normalization_test.cc | 4 +- .../xla/service/buffer_assignment_test.cc | 4 +- .../cpu/cpu_instruction_fusion_test.cc | 4 +- .../compiler/xla/service/graphviz_example.cc | 4 +- tensorflow/compiler/xla/service/hlo.proto | 2 +- .../xla/service/hlo_computation_test.cc | 12 +-- .../xla/service/hlo_creation_utils.cc | 9 +- .../compiler/xla/service/hlo_creation_utils.h | 9 +- .../xla/service/hlo_dataflow_analysis_test.cc | 4 +- .../compiler/xla/service/hlo_evaluator.cc | 2 +- .../compiler/xla/service/hlo_evaluator.h | 2 +- .../compiler/xla/service/hlo_instruction.cc | 47 +++++---- .../compiler/xla/service/hlo_instruction.h | 16 ++- .../xla/service/hlo_instruction_test.cc | 6 +- .../compiler/xla/service/hlo_instructions.cc | 2 +- .../compiler/xla/service/hlo_instructions.h | 2 +- tensorflow/compiler/xla/service/hlo_parser.cc | 26 ++--- .../xla/service/indexed_array_analysis.cc | 8 +- .../xla/service/indexed_array_analysis.h | 13 +-- .../service/tuple_points_to_analysis_test.cc | 4 +- .../compiler/xla/tests/hlo_test_base.cc | 6 +- tensorflow/compiler/xla/tests/hlo_test_base.h | 2 +- tensorflow/compiler/xla/xla_data.proto | 2 +- 39 files changed, 218 insertions(+), 238 deletions(-) diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index b2f026df6c..3f928a1bea 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -97,9 +97,9 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): args=(np.array([0xFFFFFFFF, 16], dtype=np.uint32), np.uint32(4)), expected=np.array([0xFFFFFFFF, 1], dtype=np.uint32)) - PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfigProto.DEFAULT, - xla_data_pb2.PrecisionConfigProto.HIGH, - xla_data_pb2.PrecisionConfigProto.HIGHEST) + PRECISION_VALUES = (None, xla_data_pb2.PrecisionConfig.DEFAULT, + xla_data_pb2.PrecisionConfig.HIGH, + xla_data_pb2.PrecisionConfig.HIGHEST) @parameterized.parameters(*PRECISION_VALUES) def testConv(self, precision): @@ -120,7 +120,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.output_spatial_dimensions.extend(range(2, 2 + num_spatial_dims)) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.conv( lhs, @@ -151,7 +151,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dnums.rhs_batch_dimensions.append(0) precision_config = None if precision: - precision_config = xla_data_pb2.PrecisionConfigProto() + precision_config = xla_data_pb2.PrecisionConfig() precision_config.operand_precision.extend([precision, precision]) return xla.dot_general( lhs, diff --git a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc index 8848623868..fecc7c556e 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_conv_op.cc @@ -84,7 +84,7 @@ class XlaConvOp : public XlaOpKernel { private: xla::ConvolutionDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaConvOp); }; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index 2fed53e5c0..40b15b5579 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -54,7 +54,7 @@ class XlaDotOp : public XlaOpKernel { private: xla::DotDimensionNumbers dnums_; - xla::PrecisionConfigProto precision_config_; + xla::PrecisionConfig precision_config_; TF_DISALLOW_COPY_AND_ASSIGN(XlaDotOp); }; diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index d8c050d09e..64f2d781a6 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -28,7 +28,7 @@ namespace tensorflow { xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, bool transpose_y, bool conjugate_x, bool conjugate_y, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = x.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); @@ -96,7 +96,7 @@ xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x, y = xla::Conj(y); } - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index 6cfccd5553..6edd63a4d3 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -43,11 +43,11 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot(xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, - bool transpose_y = false, bool conjugate_x = false, - bool conjugate_y = false, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::DEFAULT); +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, bool transpose_x = false, + bool transpose_y = false, bool conjugate_x = false, + bool conjugate_y = false, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index c50a8de33e..ab3d0a5668 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -50,7 +50,7 @@ namespace { // l[..., j, j] // return l xla::XlaOp CholeskyUnblocked(xla::XlaOp a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); @@ -150,7 +150,7 @@ xla::XlaOp CholeskyUnblocked(xla::XlaOp a, } // namespace xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 60cd7ded53..9a561c34b9 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -30,9 +30,9 @@ namespace tensorflow { // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. // TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky(xla::XlaOp a, int64 block_size = 256, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp Cholesky( + xla::XlaOp a, int64 block_size = 256, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/qr.cc b/tensorflow/compiler/tf2xla/lib/qr.cc index 0a140fa93c..6b3f2b6e06 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.cc +++ b/tensorflow/compiler/tf2xla/lib/qr.cc @@ -150,7 +150,7 @@ struct QRBlockResult { xla::XlaOp vs; // Shape: [..., m, n] }; xla::StatusOr QRBlock( - xla::XlaOp a, xla::PrecisionConfigProto::Precision precision) { + xla::XlaOp a, xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); @@ -257,7 +257,7 @@ xla::StatusOr QRBlock( xla::StatusOr ComputeWYRepresentation( xla::PrimitiveType type, absl::Span batch_dims, xla::XlaOp vs, xla::XlaOp taus, int64 m, int64 n, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { std::vector batch_dim_indices(batch_dims.size()); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); int64 n_index = batch_dims.size() + 1; @@ -332,7 +332,7 @@ xla::StatusOr ComputeWYRepresentation( // rather than WY transformations. xla::StatusOr QRDecomposition( xla::XlaOp a, bool full_matrices, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = xla::ShapeUtil::Rank(a_shape); diff --git a/tensorflow/compiler/tf2xla/lib/qr.h b/tensorflow/compiler/tf2xla/lib/qr.h index 8a389fb7b0..24b537ac8b 100644 --- a/tensorflow/compiler/tf2xla/lib/qr.h +++ b/tensorflow/compiler/tf2xla/lib/qr.h @@ -35,8 +35,7 @@ struct QRDecompositionResult { xla::StatusOr QRDecomposition( xla::XlaOp a, bool full_matrices, int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 37b2240b45..6524c2a9b1 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -110,9 +110,9 @@ xla::XlaOp DiagonalBlocks(xla::XlaOp a, int64 block_size) { }); } -xla::XlaOp InvertDiagonalBlocks( - xla::XlaOp diag_blocks, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { +xla::XlaOp InvertDiagonalBlocks(xla::XlaOp diag_blocks, bool lower, + bool transpose_a, bool conjugate_a, + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = diag_blocks.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { // Input is a batch of square lower triangular square matrices. Its shape is @@ -216,7 +216,7 @@ xla::XlaOp InvertDiagonalBlocks( dnums.add_rhs_batch_dimensions(0); dnums.add_lhs_contracting_dimensions(2); dnums.add_rhs_contracting_dimensions(1); - xla::PrecisionConfigProto precision_proto; + xla::PrecisionConfig precision_proto; precision_proto.add_operand_precision(precision); precision_proto.add_operand_precision(precision); auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto); @@ -245,7 +245,7 @@ xla::XlaOp InvertDiagonalBlocks( xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp a, xla::XlaOp b, xla::XlaOp inv_diag_blocks, bool left_side, bool lower, bool transpose_a, bool conjugate_a, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape blocks_shape, @@ -346,7 +346,7 @@ xla::XlaOp SolveWithInvertedDiagonalBlocks( xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, bool conjugate_a, int64 block_size, - xla::PrecisionConfigProto::Precision precision) { + xla::PrecisionConfig::Precision precision) { xla::XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index ac42a48352..2303234f36 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -57,11 +57,10 @@ namespace tensorflow { // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -xla::XlaOp TriangularSolve(xla::XlaOp a, xla::XlaOp b, bool left_side, - bool lower, bool transpose_a, bool conjugate_a, - int64 block_size = 128, - xla::PrecisionConfigProto::Precision precision = - xla::PrecisionConfigProto::HIGHEST); +xla::XlaOp TriangularSolve( + xla::XlaOp a, xla::XlaOp b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 128, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 2cd9ae799f..68cfdc1785 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -83,7 +83,7 @@ lhs_dilation: dilation to apply between input elements rhs_dilation: dilation to apply between kernel elements feature_group_count: number of feature groups for grouped convolution. dimension_numbers: a serialized xla::ConvolutionDimensionNumbers proto. -precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. )doc"); REGISTER_OP("XlaDot") @@ -102,7 +102,7 @@ Wraps the XLA ConvGeneralDilated operator, documented at lhs: the LHS tensor rhs: the RHS tensor dimension_numbers: a serialized xla::DotDimensionNumbers proto. -precision_config: a serialized xla::PrecisionConfigProto proto. +precision_config: a serialized xla::PrecisionConfig proto. )doc"); REGISTER_OP("XlaDynamicUpdateSlice") diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7f2125f74c..887b970661 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -820,7 +820,7 @@ XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs, } XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -828,14 +828,13 @@ XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs, dimension_numbers.add_lhs_contracting_dimensions( lhs_shape.dimensions_size() == 1 ? 0 : 1); dimension_numbers.add_rhs_contracting_dimensions(0); - return DotGeneral(lhs, rhs, dimension_numbers, precision_config_proto); + return DotGeneral(lhs, rhs, dimension_numbers, precision_config); }); } -XlaOp XlaBuilder::DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { +XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -844,8 +843,8 @@ XlaOp XlaBuilder::DotGeneral( ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dimension_numbers)); *instr.mutable_dot_dimension_numbers() = dimension_numbers; - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs}); }); @@ -899,28 +898,26 @@ Status XlaBuilder::VerifyConvolution( XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size()), - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } XlaOp XlaBuilder::ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs)); @@ -948,7 +945,7 @@ XlaOp XlaBuilder::ConvWithGeneralDimensions( MakePadding(base_area_dimensions, window_dimensions, window_strides, padding), dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); }); } @@ -956,11 +953,10 @@ XlaOp XlaBuilder::ConvGeneral( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp XlaBuilder::ConvGeneralDilated( @@ -968,8 +964,7 @@ XlaOp XlaBuilder::ConvGeneralDilated( absl::Span> padding, absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs)); @@ -996,8 +991,8 @@ XlaOp XlaBuilder::ConvGeneralDilated( *instr.mutable_convolution_dimension_numbers() = dimension_numbers; instr.set_feature_group_count(feature_group_count); - if (precision_config_proto != nullptr) { - *instr.mutable_precision_config() = *precision_config_proto; + if (precision_config != nullptr) { + *instr.mutable_precision_config() = *precision_config; } return AddInstruction(std::move(instr), HloOpcode::kConvolution, @@ -2594,43 +2589,40 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, } XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->Dot(lhs, rhs, precision_config_proto); + const PrecisionConfig* precision_config) { + return lhs.builder()->Dot(lhs, rhs, precision_config); } XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->DotGeneral(lhs, rhs, dimension_numbers, - precision_config_proto); + precision_config); } XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->Conv(lhs, rhs, window_strides, padding, - feature_group_count, precision_config_proto); + feature_group_count, precision_config); } -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, - absl::Span> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { - return lhs.builder()->ConvWithGeneralPadding(lhs, rhs, window_strides, - padding, feature_group_count, - precision_config_proto); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count, + const PrecisionConfig* precision_config) { + return lhs.builder()->ConvWithGeneralPadding( + lhs, rhs, window_strides, padding, feature_group_count, precision_config); } XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + int64 feature_group_count, const PrecisionConfig* precision_config) { return lhs.builder()->ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, @@ -2638,10 +2630,10 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneral(lhs, rhs, window_strides, padding, dimension_numbers, feature_group_count, - precision_config_proto); + precision_config); } XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, @@ -2651,10 +2643,10 @@ XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto) { + const PrecisionConfig* precision_config) { return lhs.builder()->ConvGeneralDilated( lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, - dimension_numbers, feature_group_count, precision_config_proto); + dimension_numbers, feature_group_count, precision_config); } XlaOp Fft(const XlaOp& operand, FftType fft_type, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 59fbc664f2..58e8f4e7fa 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -496,20 +496,19 @@ class XlaBuilder { // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. - XlaOp DotGeneral( - const XlaOp& lhs, const XlaOp& rhs, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). @@ -518,7 +517,7 @@ class XlaBuilder { absl::Span window_strides, absl::Span> padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -527,29 +526,27 @@ class XlaBuilder { absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. - XlaOp ConvGeneral( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. - XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. @@ -1150,32 +1147,30 @@ class XlaBuilder { friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, absl::Span broadcast_dimensions); friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_number, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralPadding( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + const PrecisionConfig* precision_config); friend XlaOp ConvGeneralDilated( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, @@ -1183,8 +1178,7 @@ class XlaBuilder { absl::Span lhs_dilation, absl::Span rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count, - const PrecisionConfigProto* precision_config_proto); + int64 feature_group_count, const PrecisionConfig* precision_config); friend XlaOp Fft(const XlaOp& operand, FftType fft_type, absl::Span fft_length); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, @@ -1629,27 +1623,27 @@ XlaOp Le(const XlaOp& lhs, const XlaOp& rhs, // Enqueues a dot instruction onto the computation. XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a general dot instruction onto the computation. XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, which uses the // default convolution dimension numbers. XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, - absl::Span> padding, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided dimension numbers configuration. @@ -1657,7 +1651,7 @@ XlaOp ConvWithGeneralDimensions( const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration as well as the dimension numbers. @@ -1666,17 +1660,18 @@ XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs, absl::Span> padding, const ConvolutionDimensionNumbers& dimension_numbers, int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); + const PrecisionConfig* precision_config = nullptr); // Enqueues a convolution instruction onto the computation, with the caller // provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - const XlaOp& lhs, const XlaOp& rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64 feature_group_count = 1, - const PrecisionConfigProto* precision_config_proto = nullptr); +XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64 feature_group_count = 1, + const PrecisionConfig* precision_config = nullptr); // Enqueues an FFT instruction onto the computation, of the given type and // with the given FFT length. diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 8a05d1b0d7..9f1afa2671 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -574,9 +574,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloInstruction* rhs_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfigProto::DEFAULT); + /*new_size=*/2, PrecisionConfig::DEFAULT); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1, window, dnums, precision_config)); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 0db74bd038..aa40fba9bb 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -2379,9 +2379,9 @@ TEST_P(ConvFilterPaddingTest, DoIt) { // Add a PrecisionConfig and check that AlgebraicSimplifier keeps it in place // after the transformation. - PrecisionConfigProto precision_config; - precision_config.add_operand_precision(PrecisionConfigProto::HIGH); - precision_config.add_operand_precision(PrecisionConfigProto::HIGHEST); + PrecisionConfig precision_config; + precision_config.add_operand_precision(PrecisionConfig::HIGH); + precision_config.add_operand_precision(PrecisionConfig::HIGHEST); orig_conv->set_precision_config(precision_config); auto module = CreateNewModule(); @@ -2401,9 +2401,8 @@ TEST_P(ConvFilterPaddingTest, DoIt) { conv->operand(1)->shape().dimensions(2), conv->operand(1)->shape().dimensions(3), testcase.expected_conv_window)); - EXPECT_THAT( - conv->precision_config().operand_precision(), - ElementsAre(PrecisionConfigProto::HIGH, PrecisionConfigProto::HIGHEST)); + EXPECT_THAT(conv->precision_config().operand_precision(), + ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::HIGHEST)); } } diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index d480d72297..933cf873e0 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -308,9 +308,9 @@ TEST_F(BFloat16NormalizationTest, DoNotAddUnsupportedMixedPrecision) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); HloInstruction* dot = builder.AddInstruction( HloInstruction::CreateDot(bf16_shape, a, b, dot_dnums, precision_config)); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 7398f105a0..56bd67fb55 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1490,9 +1490,9 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); auto dot_ab = builder.AddInstruction(HloInstruction::CreateDot( shape_2x4, param_a, param_b, dot_dnums, precision_config)); auto dot_bc = builder.AddInstruction(HloInstruction::CreateDot( diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 6bd0a2dd90..0fea462c85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -38,9 +38,9 @@ std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums, precision_config); } diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 0a49d85c6d..ef70b68877 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -112,9 +112,9 @@ std::unique_ptr MakeBigGraph() { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfigProto::DEFAULT); + /*new_size=*/2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction(HloInstruction::CreateDot( vshape, clamp, param_v0, dot_dnums, precision_config)); auto tuple = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 58b7af93eb..99d0cf50ca 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -172,7 +172,7 @@ message HloInstructionProto { xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; // Precision configuration for the instruction. Has backend-specific meaning. - xla.PrecisionConfigProto precision_config = 51; + xla.PrecisionConfig precision_config = 51; // Collective permute field. repeated SourceTarget source_target_pairs = 52; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index a2c1ce34c6..2aaaef1d36 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -601,9 +601,9 @@ TEST_F(HloComputationTest, Stringification) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); @@ -636,9 +636,9 @@ TEST_F(HloComputationTest, StringificationIndent) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); @@ -672,9 +672,9 @@ TEST_F(HloComputationTest, StringificationCanonical) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); builder.AddInstruction( HloInstruction::CreateDot(sout, x, reshape, dot_dnums, precision_config)); auto module = CreateNewModule(); diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc index a6ae0337a5..a3fcc0fefa 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc @@ -63,7 +63,7 @@ StatusOr MakeSliceHlo(HloInstruction* operand, StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN(Shape convolve_shape, @@ -167,10 +167,9 @@ StatusOr MakeConcatHlo( HloInstruction::CreateConcatenate(concat_shape, operands, dimension)); } -StatusOr MakeDotHlo( - HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config) { +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config) { HloComputation* computation = lhs->parent(); CHECK_EQ(computation, rhs->parent()); TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 1c82956907..b22058abb4 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -50,7 +50,7 @@ StatusOr MakeSliceHlo(HloInstruction* operand, StatusOr MakeConvolveHlo( HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); // Creates a transpose HLO instruction and adds it to the computation containing // `operand`. @@ -98,10 +98,9 @@ StatusOr MakeConcatHlo( // Creates a Dot HLO instruction and adds it to the computation containing `lhs` // and `rhs` (both must be in the same computation). -StatusOr MakeDotHlo( - HloInstruction* lhs, HloInstruction* rhs, - const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config); +StatusOr MakeDotHlo(HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config); // Creates a Map HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 62eea2b06c..72b236801a 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2334,9 +2334,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - 2, PrecisionConfigProto::DEFAULT); + 2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index ffb3451164..d0d955fea8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -345,7 +345,7 @@ StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( StatusOr> HloEvaluator::EvaluateDotOp( const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, const Literal& lhs, + const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs) { std::unique_ptr lhs_instr = HloInstruction::CreateConstant(lhs.CloneToUnique()); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index e13af8e999..72252bafc7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -116,7 +116,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { StatusOr> EvaluateDotOp( const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, const Literal& lhs, + const PrecisionConfig& precision_config, const Literal& lhs, const Literal& rhs); protected: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f25761ac70..471a12d6aa 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -347,9 +347,9 @@ StatusOr> HloInstruction::CreateFromProto( << proto.operand_ids_size(); TF_RET_CHECK(proto.has_window()); TF_RET_CHECK(proto.has_convolution_dimension_numbers()); - PrecisionConfigProto precision_config = proto.precision_config(); + PrecisionConfig precision_config = proto.precision_config(); precision_config.mutable_operand_precision()->Resize( - proto.operand_ids_size(), PrecisionConfigProto::DEFAULT); + proto.operand_ids_size(), PrecisionConfig::DEFAULT); instruction = CreateConvolve( proto.shape(), operands(0), operands(1), std::max(proto.feature_group_count(), 1), proto.window(), @@ -475,7 +475,7 @@ StatusOr> HloInstruction::CreateFromProto( if (instruction->opcode() == HloOpcode::kDot) { instruction->precision_config_ = proto.precision_config(); instruction->precision_config_.mutable_operand_precision()->Resize( - instruction->operand_count(), PrecisionConfigProto::DEFAULT); + instruction->operand_count(), PrecisionConfig::DEFAULT); TF_RET_CHECK(proto.has_dot_dimension_numbers()); instruction->dot_dimension_numbers_ = absl::make_unique( @@ -657,7 +657,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config) { return absl::make_unique( shape, lhs, rhs, feature_group_count, window, dimension_numbers, precision_config); @@ -673,7 +673,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config) { auto instruction = absl::WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); instruction->AppendOperand(lhs); @@ -2888,8 +2888,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return absl::AsciiStrToLower(RandomDistribution_Name(distribution)); } -string PrecisionToString(const PrecisionConfigProto::Precision& precision) { - return absl::AsciiStrToLower(PrecisionConfigProto::Precision_Name(precision)); +string PrecisionToString(const PrecisionConfig::Precision& precision) { + return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision)); } string ConvolutionDimensionNumbersToString( @@ -2967,32 +2967,31 @@ StatusOr StringToRandomDistribution(const string& name) { string HloInstruction::PrecisionConfigToString() const { if (absl::c_all_of( precision_config_.operand_precision(), [](int32 precision) { - return static_cast(precision) == - PrecisionConfigProto::DEFAULT; + return static_cast(precision) == + PrecisionConfig::DEFAULT; })) { return ""; } return StrCat( "operand_precision={", - StrJoin(precision_config_.operand_precision(), ",", - [](string* out, int32 precision) { - CHECK(PrecisionConfigProto::Precision_IsValid(precision)) - << precision; - StrAppend(out, PrecisionToString( - static_cast( - precision))); - }), + StrJoin( + precision_config_.operand_precision(), ",", + [](string* out, int32 precision) { + CHECK(PrecisionConfig::Precision_IsValid(precision)) << precision; + StrAppend(out, + PrecisionToString( + static_cast(precision))); + }), "}"); } -StatusOr StringToPrecision( - const string& name) { - static std::unordered_map* map = [] { +StatusOr StringToPrecision(const string& name) { + static std::unordered_map* map = [] { static auto* map = - new std::unordered_map; - for (int i = 0; i < PrecisionConfigProto::Precision_ARRAYSIZE; i++) { - if (PrecisionConfigProto::Precision_IsValid(i)) { - auto value = static_cast(i); + new std::unordered_map; + for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) { + if (PrecisionConfig::Precision_IsValid(i)) { + auto value = static_cast(i); (*map)[PrecisionToString(value)] = value; } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 55d592ff94..691f8155f9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -407,7 +407,7 @@ class HloInstruction { const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); // Creates an FFT op, of the type indicated by fft_type. static std::unique_ptr CreateFft( @@ -419,7 +419,7 @@ class HloInstruction { static std::unique_ptr CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS @@ -1262,10 +1262,8 @@ class HloInstruction { // information. Transformations to other HLOs will not preserve this // information but it is presumed that the alternate lowering is strictly // superior. - const PrecisionConfigProto& precision_config() const { - return precision_config_; - } - void set_precision_config(const PrecisionConfigProto& precision_config) { + const PrecisionConfig& precision_config() const { return precision_config_; } + void set_precision_config(const PrecisionConfig& precision_config) { precision_config_ = precision_config; } @@ -1680,7 +1678,7 @@ class HloInstruction { // Information used to communicate to the implementation about the algorithm // used to produce results. See the documentation on precision_config(). - PrecisionConfigProto precision_config_; + PrecisionConfig precision_config_; // String identifier for instruction. string name_; @@ -1704,12 +1702,12 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); -string PrecisionToString(const PrecisionConfigProto::Precision& precision); +string PrecisionToString(const PrecisionConfig::Precision& precision); string ConvolutionDimensionNumbersToString( const ConvolutionDimensionNumbers& dnums); StatusOr StringToRandomDistribution(const string& name); -StatusOr StringToPrecision(const string& name); +StatusOr StringToPrecision(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 9eab6eea80..c1b7c3832b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1752,9 +1752,9 @@ TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) { auto* conv = module->entry_computation()->root_instruction(); auto clone = conv->Clone(); - EXPECT_THAT(clone->precision_config().operand_precision(), - ::testing::ElementsAre(PrecisionConfigProto::HIGH, - PrecisionConfigProto::DEFAULT)); + EXPECT_THAT( + clone->precision_config().operand_precision(), + ::testing::ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT)); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e3683aaec9..ad87aa1123 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1630,7 +1630,7 @@ HloConvolutionInstruction::HloConvolutionInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config) + const PrecisionConfig& precision_config) : HloInstruction(HloOpcode::kConvolution, shape), feature_group_count_(feature_group_count), window_(window), diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 1c85aa4681..e1215a7566 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -944,7 +944,7 @@ class HloConvolutionInstruction : public HloInstruction { const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers, - const PrecisionConfigProto& precision_config); + const PrecisionConfig& precision_config); const Window& window() const override { return window_; } void set_window(const Window& window) override { window_ = window; } const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 62f01c4adb..0f26ed4235 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -221,7 +221,7 @@ class HloParser { bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); - bool ParsePrecisionList(std::vector* result); + bool ParsePrecisionList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -240,7 +240,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParsePrecision(PrecisionConfigProto::Precision* result); + bool ParsePrecision(PrecisionConfig::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -909,7 +909,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; - optional> operand_precision; + optional> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -922,13 +922,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!feature_group_count) { feature_group_count = 1; } - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; if (operand_precision) { *precision_config.mutable_operand_precision() = { operand_precision->begin(), operand_precision->end()}; } else { precision_config.mutable_operand_precision()->Resize( - operands.size(), PrecisionConfigProto::DEFAULT); + operands.size(), PrecisionConfig::DEFAULT); } instruction = builder->AddInstruction(HloInstruction::CreateConvolve( shape, /*lhs=*/operands[0], /*rhs=*/operands[1], @@ -1279,7 +1279,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; - optional> operand_precision; + optional> operand_precision; attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, &operand_precision}; @@ -1306,13 +1306,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, rhs_batch_dims->end()}; } - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; if (operand_precision) { *precision_config.mutable_operand_precision() = { operand_precision->begin(), operand_precision->end()}; } else { precision_config.mutable_operand_precision()->Resize( - operands.size(), PrecisionConfigProto::DEFAULT); + operands.size(), PrecisionConfig::DEFAULT); } instruction = builder->AddInstruction(HloInstruction::CreateDot( @@ -2410,11 +2410,11 @@ bool HloParser::ParseAttributeHelper( return ParseDomain(static_cast(attr_out_ptr)); } case AttrTy::kPrecisionList: { - std::vector result; + std::vector result; if (!ParsePrecisionList(&result)) { return false; } - static_cast>*>( + static_cast>*>( attr_out_ptr) ->emplace(result); return true; @@ -2698,9 +2698,9 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= /*empty*/ // ::= precision_val (delim precision_val)* bool HloParser::ParsePrecisionList( - std::vector* result) { + std::vector* result) { auto parse_and_add_item = [&]() { - PrecisionConfigProto::Precision item; + PrecisionConfig::Precision item; if (!ParsePrecision(&item)) { return false; } @@ -3032,7 +3032,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { +bool HloParser::ParsePrecision(PrecisionConfig::Precision* result) { VLOG(1) << "ParsePrecision"; if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects random distribution"); diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 4a71ee909b..37b774b8a5 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -1031,8 +1031,8 @@ bool CanFoldDotIntoIndexedArray( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs) { + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedLhs(" << ToString(lhs) << " " << ToString(rhs); if (!CanFoldDotIntoIndexedArray( @@ -1066,7 +1066,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedLhs( StatusOr IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, ConstantArray* lhs, + const PrecisionConfig& precision_config, ConstantArray* lhs, ScalarIndexedConstantArray* rhs) { VLOG(3) << "ComputeArrayForDotWithIndexedRhs(" << ToString(lhs) << " " << ToString(rhs); @@ -1101,7 +1101,7 @@ IndexedArrayAnalysis::ComputeArrayForDotWithIndexedRhs( StatusOr IndexedArrayAnalysis::ComputeArrayForDot( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs) { + const PrecisionConfig& precision_config, Array* lhs, Array* rhs) { // Intuitively, if // // - The LHS of a dot product is a gathered sequence of rows from a constant diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index f21e784a4d..9746d176cc 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -267,17 +267,18 @@ class IndexedArrayAnalysis { StatusOr ComputeArrayForDotWithIndexedLhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, - ScalarIndexedConstantArray* lhs, ConstantArray* rhs); + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs); StatusOr ComputeArrayForDotWithIndexedRhs( const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, ConstantArray* lhs, + const PrecisionConfig& precision_config, ConstantArray* lhs, ScalarIndexedConstantArray* rhs); - StatusOr ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfigProto& precision_config, Array* lhs, Array* rhs); + StatusOr ComputeArrayForDot(const Shape& shape, + const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, + Array* lhs, Array* rhs); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index e3328203a6..2b2a2eb42a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1064,9 +1064,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { DotDimensionNumbers dot_dnums; dot_dnums.add_lhs_contracting_dimensions(1); dot_dnums.add_rhs_contracting_dimensions(0); - PrecisionConfigProto precision_config; + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - /*new_size=*/2, PrecisionConfigProto::DEFAULT); + /*new_size=*/2, PrecisionConfig::DEFAULT); auto dot = builder.AddInstruction( HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config)); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index edab480091..3df99aac7d 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -121,10 +121,10 @@ StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, } /* static */ -PrecisionConfigProto HloTestBase::DefaultPrecisionConfig(int operands) { - PrecisionConfigProto precision_config; +PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { + PrecisionConfig precision_config; precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfigProto::DEFAULT); + operands, PrecisionConfig::DEFAULT); return precision_config; } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 89e72a045e..21d77c0cc4 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -80,7 +80,7 @@ class HloTestBase : public ::testing::Test { static StatusOr RunHloPass(HloPassInterface* hlo_pass, HloModule* module); - static PrecisionConfigProto DefaultPrecisionConfig(int operands); + static PrecisionConfig DefaultPrecisionConfig(int operands); protected: // This uses the interpreter backend as the reference backend and diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 8e43f275e1..dd329f1181 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -580,7 +580,7 @@ message SourceTarget { // Used to indicate the precision configuration. It has backend specific // meaning. -message PrecisionConfigProto { +message PrecisionConfig { enum Precision { DEFAULT = 0; HIGH = 1; -- GitLab From 2724362dcd8b2f1c417e4cabedd0ebdf6f6e100c Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Wed, 5 Sep 2018 14:22:37 -0700 Subject: [PATCH 139/540] Correct gradient for multi-output tfe.py_func PiperOrigin-RevId: 211698400 --- tensorflow/python/kernel_tests/py_func_test.py | 12 ++++++++++++ tensorflow/python/ops/script_ops.py | 6 +++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index 79fcbaad43..5f5e24bd63 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -566,6 +566,18 @@ class PyFuncTest(test.TestCase): dy_dx = gradients_impl.gradients(y, x)[0] self.assertEqual(self.evaluate(dy_dx), 6.0) + def testEagerGradientGraphTwoOutputs(self): + + def f(x, y): + return x * y, x / y + + x = constant_op.constant(3.0) + y = constant_op.constant(2.0) + fa, fb = script_ops.eager_py_func(f, inp=[x, y], + Tout=[dtypes.float32, dtypes.float32]) + dy_dx = gradients_impl.gradients(fa + fb, x)[0] + self.assertEqual(self.evaluate(dy_dx), 2.5) + @test_util.run_in_graph_and_eager_modes def testEagerGradientTapeMultipleArgs(self): diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 8d66de6b20..2ec4b540fb 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -287,19 +287,19 @@ def _internal_py_func(func, # TODO(akshayka): Implement higher-order derivatives. @ops.RegisterGradient("EagerPyFunc") -def _EagerPyFuncGrad(op, dy): +def _EagerPyFuncGrad(op, *dy): """Computes the gradient of an EagerPyFunc.""" token = op.get_attr("token") - def eagerly_executed_grad(dy): + def eagerly_executed_grad(*dy): tape, eager_inputs, eager_outputs = tape_cache.pop(compat.as_bytes(token)) return tape.gradient(eager_outputs, eager_inputs, output_gradients=dy) with ops.control_dependencies(op.outputs): return _internal_py_func( func=eagerly_executed_grad, - inp=[dy] if isinstance(dy, ops.Tensor) else dy, + inp=dy, Tout=[tensor.dtype for tensor in op.inputs], eager=True, is_grad_func=True) -- GitLab From a3c1ccd1da64040eeb139a0c6c1fc34ae46d7290 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 5 Sep 2018 14:33:37 -0700 Subject: [PATCH 140/540] Deprecate `tf.train.batch()` and related APIs. These APIs are based on queue runners, which have been deprecated and will be removed in TensorFlow 2.0. They have been replaced with `tf.data.Dataset`, which provides a more efficient version of the same functionality. PiperOrigin-RevId: 211700442 --- tensorflow/python/training/input.py | 48 +++++++++++++++---- .../api/golden/v2/tensorflow.train.pbtxt | 32 ------------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 0d6207f8c4..94c6b47027 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -45,6 +45,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.summary import summary from tensorflow.python.training import queue_runner +from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export @@ -894,7 +895,11 @@ def _shuffle_batch_join(tensors_list, batch_size, capacity, # Batching functions ---------------------------------------------------------- -@tf_export("train.batch") +@tf_export(v1=["train.batch"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.batch(batch_size)` (or `padded_batch(...)` if " + "`dynamic_pad=True`).") def batch(tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -989,7 +994,11 @@ def batch(tensors, batch_size, num_threads=1, capacity=32, name=name) -@tf_export("train.maybe_batch") +@tf_export(v1=["train.maybe_batch"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.filter(...).batch(batch_size)` (or `padded_batch(...)`" + " if `dynamic_pad=True`).") def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1042,7 +1051,11 @@ def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32, name=name) -@tf_export("train.batch_join") +@tf_export(v1=["train.batch_join"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.interleave(...).batch(batch_size)` (or " + "`padded_batch(...)` if `dynamic_pad=True`).") def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1148,7 +1161,11 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, name=name) -@tf_export("train.maybe_batch_join") +@tf_export(v1=["train.maybe_batch_join"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.interleave(...).filter(...).batch(batch_size)` (or " + "`padded_batch(...)` if `dynamic_pad=True`).") def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, @@ -1201,7 +1218,10 @@ def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32, name=name) -@tf_export("train.shuffle_batch") +@tf_export(v1=["train.shuffle_batch"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.shuffle(min_after_dequeue).batch(batch_size)`.") def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -1301,7 +1321,11 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, name=name) -@tf_export("train.maybe_shuffle_batch") +@tf_export(v1=["train.maybe_shuffle_batch"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.filter(...).shuffle(min_after_dequeue).batch(batch_size)`" + ".") def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, keep_input, num_threads=1, seed=None, enqueue_many=False, shapes=None, @@ -1361,7 +1385,11 @@ def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, name=name) -@tf_export("train.shuffle_batch_join") +@tf_export(v1=["train.shuffle_batch_join"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.interleave(...).shuffle(min_after_dequeue).batch" + "(batch_size)`.") def shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, @@ -1455,7 +1483,11 @@ def shuffle_batch_join(tensors_list, batch_size, capacity, name=name) -@tf_export("train.maybe_shuffle_batch_join") +@tf_export(v1=["train.maybe_shuffle_batch_join"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.interleave(...).filter(...).shuffle(min_after_dequeue)" + ".batch(batch_size)`.") def maybe_shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, keep_input, seed=None, enqueue_many=False, shapes=None, diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt index c35e254843..e2b74e4d67 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt @@ -248,14 +248,6 @@ tf_module { name: "basic_train_loop" argspec: "args=[\'supervisor\', \'train_step_fn\', \'args\', \'kwargs\', \'master\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'\'], " } - member_method { - name: "batch" - argspec: "args=[\'tensors\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " - } - member_method { - name: "batch_join" - argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " - } member_method { name: "checkpoint_exists" argspec: "args=[\'checkpoint_prefix\'], varargs=None, keywords=None, defaults=None" @@ -352,22 +344,6 @@ tf_module { name: "match_filenames_once" argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "maybe_batch" - argspec: "args=[\'tensors\', \'keep_input\', \'batch_size\', \'num_threads\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " - } - member_method { - name: "maybe_batch_join" - argspec: "args=[\'tensors_list\', \'keep_input\', \'batch_size\', \'capacity\', \'enqueue_many\', \'shapes\', \'dynamic_pad\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'False\', \'False\', \'None\', \'None\'], " - } - member_method { - name: "maybe_shuffle_batch" - argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " - } - member_method { - name: "maybe_shuffle_batch_join" - argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'keep_input\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " - } member_method { name: "natural_exp_decay" argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -408,14 +384,6 @@ tf_module { name: "sdca_shrink_l1" argspec: "args=[\'weights\', \'l1\', \'l2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "shuffle_batch" - argspec: "args=[\'tensors\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'num_threads\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " - } - member_method { - name: "shuffle_batch_join" - argspec: "args=[\'tensors_list\', \'batch_size\', \'capacity\', \'min_after_dequeue\', \'seed\', \'enqueue_many\', \'shapes\', \'allow_smaller_final_batch\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\', \'None\', \'None\'], " - } member_method { name: "slice_input_producer" argspec: "args=[\'tensor_list\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " -- GitLab From 75390d4c3568358ea81a072b0ccc94071022c38d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 14:40:57 -0700 Subject: [PATCH 141/540] Special-case the AccumulateNV2 op in print_selective_registration_header AccumulateNV2 doesn't have or need a kernel. It gets rewritten to other ops by accumulate_n_optimizer.cc. This change allows it to be mentioned in the output of print_selective_registration_header, rather than being ignored with a warning. Behavior for other ops is preserved. PiperOrigin-RevId: 211701878 --- .../print_selective_registration_header_test.py | 12 ++++++++++++ .../tools/selective_registration_header_lib.py | 17 +++++++++++++++-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py index 4b3d98242c..cce8060fb9 100644 --- a/tensorflow/python/tools/print_selective_registration_header_test.py +++ b/tensorflow/python/tools/print_selective_registration_header_test.py @@ -59,6 +59,9 @@ GRAPH_DEF_TXT = """ } """ +# AccumulateNV2 is included because it should be included in the header despite +# lacking a kernel (it's rewritten by AccumulateNV2RemovePass; see +# core/common_runtime/accumulate_n_optimizer.cc. GRAPH_DEF_TXT_2 = """ node: { name: "node_4" @@ -67,6 +70,12 @@ GRAPH_DEF_TXT_2 = """ device: "/cpu:0" attr: { key: "T" value: { type: DT_FLOAT } } } + node: { + name: "node_5" + op: "AccumulateNV2" + attr: { key: "T" value: { type: DT_INT32 } } + attr: { key : "N" value: { i: 3 } } + } """ @@ -100,6 +109,7 @@ class PrintOpFilegroupTest(test.TestCase): self.assertListEqual( [ + ('AccumulateNV2', None), # ('BiasAdd', 'BiasOp'), # ('MatMul', matmul_prefix + 'MatMulOp'), # @@ -117,6 +127,7 @@ class PrintOpFilegroupTest(test.TestCase): 'rawproto', self.WriteGraphFiles(graphs), default_ops) self.assertListEqual( [ + ('AccumulateNV2', None), # ('BiasAdd', 'BiasOp'), # ('MatMul', matmul_prefix + 'MatMulOp'), # @@ -196,6 +207,7 @@ class PrintOpFilegroupTest(test.TestCase): constexpr inline bool ShouldRegisterOp(const char op[]) { return false + || isequal(op, "AccumulateNV2") || isequal(op, "BiasAdd") ; } diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py index dc0612bb3f..b99c632c3e 100644 --- a/tensorflow/python/tools/selective_registration_header_lib.py +++ b/tensorflow/python/tools/selective_registration_header_lib.py @@ -32,6 +32,16 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging +# Usually, we use each graph node to induce registration of an op and +# corresponding kernel; nodes without a corresponding kernel (perhaps due to +# attr types) generate a warning but are otherwise ignored. Ops in this set are +# registered even if there's no corresponding kernel. +OPS_WITHOUT_KERNEL_WHITELIST = frozenset([ + # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see + # core/common_runtime/accumulate_n_optimizer.cc. + 'AccumulateNV2' +]) + def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): """Gets the ops and kernels needed from the model files.""" @@ -53,8 +63,10 @@ def get_ops_and_kernels(proto_fileformat, proto_files, default_ops_str): node_def.device = '/cpu:0' kernel_class = pywrap_tensorflow.TryFindKernelClass( node_def.SerializeToString()) - if kernel_class: - op_and_kernel = (str(node_def.op), str(kernel_class.decode('utf-8'))) + op = str(node_def.op) + if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST: + op_and_kernel = (op, str(kernel_class.decode('utf-8')) + if kernel_class else None) if op_and_kernel not in ops: ops.add(op_and_kernel) else: @@ -129,6 +141,7 @@ def get_header_from_ops_and_kernels(ops_and_kernels, ''' line += 'constexpr const char* kNecessaryOpKernelClasses[] = {\n' for _, kernel_class in ops_and_kernels: + if kernel_class is None: continue line += '"%s",\n' % kernel_class line += '};' append(line) -- GitLab From 9a343a2be2469442ea6bb87f23fc043e1d14cc3b Mon Sep 17 00:00:00 2001 From: Suharsh Sivakumar Date: Wed, 5 Sep 2018 14:48:34 -0700 Subject: [PATCH 142/540] Skip quantization of optional tensors (tensor_idx = -1) PiperOrigin-RevId: 211703281 --- .../lite/tools/optimize/quantize_weights.cc | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc index e5bb3c990a..692efb9029 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -168,11 +168,16 @@ std::vector GetQuantizableTensorsFromOperator( bool eval_hybrid = use_hybrid_evaluation && IsHybridEvaluationOp(op, op_code); - bool skipped_tensor = false; std::vector op_input_indices = GetWeightInputIndices(op_code); for (const int32_t op_input_idx : op_input_indices) { int32_t tensor_idx = op->inputs[op_input_idx]; + if (tensor_idx == -1) { + LOG(INFO) << "Skipping optional tensor input " << op_input_idx + << " of operation " << EnumNameBuiltinOperator(op_code); + continue; + } + TensorT* tensor = subgraph->tensors[tensor_idx].get(); // TODO(suharshs): Support shared weights, i.e. If two tensors share the // same weight array, things may break. (i.e. SSD object detection) @@ -180,14 +185,12 @@ std::vector GetQuantizableTensorsFromOperator( CountTensorConsumers(model, subgraph, tensor_idx) != 1) { LOG(INFO) << "Skipping quantization of tensor " << tensor->name << " that is shared between multiple multiple operations."; - skipped_tensor = true; continue; } if (tensor->type != TensorType_FLOAT32) { LOG(INFO) << "Skipping quantization of tensor " << tensor->name << " that is not type float."; - skipped_tensor = true; continue; } @@ -196,7 +199,9 @@ std::vector GetQuantizableTensorsFromOperator( LOG(INFO) << "Skipping quantization of tensor " << tensor->name << " because it has fewer than " << weights_min_num_elements << " elements (" << num_elements << ")."; - skipped_tensor = true; + // If one of the weights isn't quantized, then we cannot use the hybrid + // kernel for this operation, since it expects everything to be quantized. + eval_hybrid = false; continue; } @@ -209,12 +214,6 @@ std::vector GetQuantizableTensorsFromOperator( tensor_infos.push_back(tensor_info); } - // For hybrid operations we either need to quantize all tensors or none. So - // if we skipped any tensors we need to return no quantized tensors. - if (eval_hybrid && skipped_tensor) { - return {}; - } - return tensor_infos; } -- GitLab From b84d27deb8c13eb426951dca6656de2f333f13d5 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 5 Sep 2018 14:58:10 -0700 Subject: [PATCH 143/540] Support converting eager tensor to tf.float16 if a numpy half is passed. This still defaults to float32 for all normal floats. PiperOrigin-RevId: 211704918 --- tensorflow/python/eager/tensor_test.py | 1 + tensorflow/python/lib/core/py_seq_tensor.cc | 25 +++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 871136e2c8..32742a9b96 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -295,6 +295,7 @@ class TFETensorUtilTest(test_util.TensorFlowTestCase): def testFloatTensor(self): self.assertEqual(dtypes.float64, _create_tensor(np.float64()).dtype) self.assertEqual(dtypes.float32, _create_tensor(np.float32()).dtype) + self.assertEqual(dtypes.float16, _create_tensor(np.float16()).dtype) self.assertEqual(dtypes.float32, _create_tensor(0.0).dtype) def testSliceDimOutOfRange(self): diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 3b4f12ae31..269142a7c2 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -55,6 +55,10 @@ bool IsPyDouble(PyObject* obj) { return PyIsInstance(obj, &PyDoubleArrType_Type); // NumPy double type. } +bool IsNumpyHalf(PyObject* obj) { + return PyIsInstance(obj, &PyHalfArrType_Type); +} + bool IsPyFloat(PyObject* obj) { return PyFloat_Check(obj) || PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types @@ -156,6 +160,8 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) { } } else if (IsPyDouble(obj)) { *dtype = DT_DOUBLE; + } else if (IsNumpyHalf(obj)) { + *dtype = DT_HALF; } else if (IsPyFloat(obj)) { *dtype = DT_FLOAT; } else if (PyBool_Check(obj) || PyIsInstance(obj, &PyBoolArrType_Type)) { @@ -357,6 +363,17 @@ const char* ConvertOneFloat(PyObject* v, T* out) { DEFINE_HELPER(ConvertDouble, double, DT_DOUBLE, ConvertOneFloat); DEFINE_HELPER(ConvertFloat, float, DT_FLOAT, ConvertOneFloat); +const char* ConvertOneNumpyHalf(PyObject* v, Eigen::half* out) { + // NOTE(nareshmodi): Is there a way to convert to C double without the + // intermediate Python double? This will help with ConvertOneFloat as well. + Safe_PyObjectPtr as_float = make_safe(PyNumber_Float(v)); + double v_double = PyFloat_AS_DOUBLE(as_float.get()); + *out = Eigen::half(v_double); + + return nullptr; +} +DEFINE_HELPER(ConvertNumpyHalf, Eigen::half, DT_HALF, ConvertOneNumpyHalf); + // String support const char* ConvertOneString(PyObject* v, string* out) { @@ -452,6 +469,9 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { if (ConvertDouble(obj, shape, ret) == nullptr) return Status::OK(); break; + case DT_HALF: + RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret)); + case DT_INT64: if (ConvertInt64(obj, shape, ret) == nullptr) return Status::OK(); break; @@ -489,8 +509,13 @@ Status PySeqToTensor(PyObject* obj, PyObject* dtype, Tensor* ret) { // final type. RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); } + case DT_DOUBLE: RETURN_STRING_AS_STATUS(ConvertDouble(obj, shape, ret)); + + case DT_HALF: + RETURN_STRING_AS_STATUS(ConvertNumpyHalf(obj, shape, ret)); + case DT_INT64: if (requested_dtype == DT_INVALID) { const char* error = ConvertInt32(obj, shape, ret); -- GitLab From 40e262c0dc3f6eafe46978f63a4d849e5fd6d69e Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Wed, 5 Sep 2018 15:00:09 -0700 Subject: [PATCH 144/540] Experimental work-in-progress support for TPUStrategy in keras. PiperOrigin-RevId: 211705274 --- .../distribute/python/examples/keras_mnist.py | 4 +- .../keras/engine/training_distributed.py | 237 ++++++++++++++---- 2 files changed, 193 insertions(+), 48 deletions(-) diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py index a20069c4fe..0495134636 100644 --- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py +++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py @@ -58,13 +58,13 @@ def get_input_datasets(): train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds = train_ds.repeat() train_ds = train_ds.shuffle(100) - train_ds = train_ds.batch(64) + train_ds = train_ds.batch(64, drop_remainder=True) # eval dataset eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)) eval_ds = eval_ds.repeat() eval_ds = eval_ds.shuffle(100) - eval_ds = eval_ds.batch(64) + eval_ds = eval_ds.batch(64, drop_remainder=True) return train_ds, eval_ds, input_shape diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py index a7bb1f8177..e440e02bfb 100644 --- a/tensorflow/python/keras/engine/training_distributed.py +++ b/tensorflow/python/keras/engine/training_distributed.py @@ -19,13 +19,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np +from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.keras import backend as K from tensorflow.python.keras import callbacks as cbks from tensorflow.python.keras import optimizers from tensorflow.python.keras.engine import distributed_training_utils from tensorflow.python.keras.utils.generic_utils import Progbar +from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribute as distribute_lib def fit_loop( @@ -64,6 +67,11 @@ def fit_loop( """ current_strategy = model._distribution_strategy + # TODO(priyag, sourabhbajaj): Remove this when the codepaths are merged. + if current_strategy.__class__.__name__ == 'TPUStrategy': + return _experimental_fit_loop( + model, iterator, epochs, initial_epoch, steps_per_epoch) + clone_model_on_towers( model, current_strategy, make_callback_model=True) @@ -116,11 +124,6 @@ def fit_loop( do_validation = False if validation_steps: do_validation = True - if steps_per_epoch is None: - raise ValueError('Can only use `validation_steps` ' - 'when doing step-wise ' - 'training, i.e. `steps_per_epoch` ' - 'must be set.') # Copy the weights from the original model to each of the replicated models. orig_model_weights = model.get_weights() @@ -140,44 +143,46 @@ def fit_loop( verbose=verbose) out_labels = model.metrics_names or [] callbacks.on_train_begin() + + assert steps_per_epoch is not None + for epoch in range(initial_epoch, epochs): callbacks.on_epoch_begin(epoch) - if steps_per_epoch is not None: - epoch_logs = {} - for step_index in range(steps_per_epoch): - batch_logs = {'batch': step_index, 'size': 1} - callbacks.on_batch_begin(step_index, batch_logs) - try: - outs = distributed_train_function(ins) - except errors.OutOfRangeError: - logging.warning('Your dataset iterator ran out of data; ' - 'interrupting training. Make sure that your dataset ' - 'can generate at least `steps_per_epoch * epochs` ' - 'batches (in this case, %d batches).' % - steps_per_epoch * epochs) - break - - if not isinstance(outs, list): - outs = [outs] - - outs = _aggregate_metrics_across_towers( - current_strategy.num_towers, out_labels, outs) - for l, o in zip(out_labels, outs): - batch_logs[l] = o - callbacks.on_batch_end(step_index, batch_logs) - if callbacks.model.stop_training: - break - if do_validation: - val_outs = test_loop( - model, - val_iterator, - steps=validation_steps, - verbose=0) - if not isinstance(val_outs, list): - val_outs = [val_outs] - # Same labels assumed. - for l, o in zip(out_labels, val_outs): - epoch_logs['val_' + l] = o + epoch_logs = {} + for step_index in range(steps_per_epoch): + batch_logs = {'batch': step_index, 'size': 1} + callbacks.on_batch_begin(step_index, batch_logs) + try: + outs = distributed_train_function(ins) + except errors.OutOfRangeError: + logging.warning('Your dataset iterator ran out of data; ' + 'interrupting training. Make sure that your dataset ' + 'can generate at least `steps_per_epoch * epochs` ' + 'batches (in this case, %d batches).' % + steps_per_epoch * epochs) + break + + if not isinstance(outs, list): + outs = [outs] + + outs = _aggregate_metrics_across_towers( + current_strategy.num_towers, out_labels, outs) + for l, o in zip(out_labels, outs): + batch_logs[l] = o + callbacks.on_batch_end(step_index, batch_logs) + if callbacks.model.stop_training: + break + if do_validation: + val_outs = test_loop( + model, + val_iterator, + steps=validation_steps, + verbose=0) + if not isinstance(val_outs, list): + val_outs = [val_outs] + # Same labels assumed. + for l, o in zip(out_labels, val_outs): + epoch_logs['val_' + l] = o callbacks.on_epoch_end(epoch, epoch_logs) if callbacks.model.stop_training: @@ -192,6 +197,139 @@ def fit_loop( return model.history +def _experimental_fit_loop( + model, + iterator, + epochs=100, + initial_epoch=0, + steps_per_epoch=None): + """fit function when using TPU DistributionStrategy for training. + + Arguments: + model: Keras Model instance. + iterator: Iterator that returns inputs and targets + epochs: Number of times to iterate over the data + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run) + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. Ignored with the default value of `None`. + + Returns: + Returns `None`. + + Raises: + ValueError: in case of invalid arguments. + """ + current_strategy = model._distribution_strategy + + # TODO(priyag): Add validation that shapes are fully defined for TPU case. + + # TODO(priyag, sourabhbajaj): This should be moved into a callback instead. + K.get_session().run(current_strategy.initialize()) + + def _per_device_train_function(model): + model._make_train_function() + return (model.train_function.inputs, + model.train_function.outputs, + model.train_function.updates_op, + model.train_function.session_kwargs) + + # TODO(priyag, sourabhbajaj): This should likely not be hardcoded here. + K.set_learning_phase(1) + + def step_fn(ctx, inputs, targets): + """Clones the model and calls make_train_function.""" + # TODO(priyag, sourabhbajaj): Should cache this keyed on input shapes. + clone_model_on_towers( + model, + current_strategy, + make_callback_model=True, + inputs=inputs, + targets=targets) + + (grouped_inputs, grouped_outputs, grouped_updates, + grouped_session_args) = current_strategy.call_for_each_tower( + _per_device_train_function, model._grouped_model) + (all_inputs, all_outputs, all_updates, + all_session_args) = distributed_training_utils.unwrap_values( + current_strategy, grouped_inputs, grouped_outputs, + grouped_updates, grouped_session_args, with_loss_tensor=True) + combined_fn = K.Function( + all_inputs, all_outputs, + updates=all_updates, + name='distributed_train_function', + **all_session_args) + + # TODO(priyag, sourabhbajaj): Perhaps the aggregation type needs to be + # something else for different outputs. + out_labels = model.metrics_names or [] + for label, output in zip(out_labels, combined_fn.outputs): + ctx.set_last_step_output(label, output, + aggregation=distribute_lib.get_loss_reduction()) + + # TODO(priyag, sourabhbajaj): Ignoring these things from the combined_fn: + # feed_dict, session kwargs, run options, run_metadata for now. These should + # be handled appropriately + return combined_fn.updates_op + + # Add initial dummy values for loss and other metric tensors. + initial_loop_values = {} + initial_loop_values['loss'] = constant_op.constant(1e7) + for name, tensor in zip(model.metrics_names[1:], model.metrics_tensors): + initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype) + + with current_strategy.scope(): + # TODO(priyag, sourabhbajaj): Adjust steps_per_run appropriately based on + # steps_per_epoch and number of epochs. + ctx = current_strategy.run_steps_on_dataset( + step_fn, iterator, iterations=current_strategy.steps_per_run, + initial_loop_values=initial_loop_values) + + train_op = ctx.run_op + output_tensors = ctx.last_step_outputs + + # Copy the weights from the original model to each of the replicated models. + orig_model_weights = model.get_weights() + with current_strategy.scope(): + distributed_model = current_strategy.unwrap(model._grouped_model)[0] + distributed_training_utils.set_weights( + current_strategy, distributed_model, orig_model_weights) + + assert steps_per_epoch is not None + + # TODO(priyag, sourabhbajaj): Add callbacks support. + # TODO(priyag, sourabhbajaj): Add validation. + for epoch in range(initial_epoch, epochs): + for step_index in range( + 0, steps_per_epoch, current_strategy.steps_per_run): + try: + _, outs = K.get_session().run([train_op, output_tensors]) + # TODO(priyag, sourabhbajaj): Remove this logging in favor of proper + # summaries through callbacks. + print('Epoch: {}, step_index: {}, loss: {}'.format( + epoch, step_index, outs['loss'])) + for label, out in outs.items(): + print(label, ': ', out) + except errors.OutOfRangeError: + logging.warning('Your dataset iterator ran out of data; ' + 'interrupting training. Make sure that your dataset ' + 'can generate at least `steps_per_epoch * epochs` ' + 'batches (in this case, %d batches).' % + steps_per_epoch * epochs) + break + + # Copy the weights back from the replicated model to the original model. + with current_strategy.scope(): + updated_weights = current_strategy.unwrap( + model._grouped_model)[0].get_weights() + model.set_weights(updated_weights) + + K.get_session().run(current_strategy.finalize()) + + # TODO(priyag, sourabhbajaj): Return history. + + def test_loop(model, iterator, verbose=0, steps=None): """evaluate method to validate a model that uses DistributionStrategy. @@ -373,12 +511,12 @@ def predict_loop(model, iterator, verbose=0, steps=None): ] -def _clone_and_build_model(model): +def _clone_and_build_model(model, inputs=None, targets=None): """Clone and build the given keras_model.""" # We need to set the import here since we run into a circular dependency # error. from tensorflow.python.keras import models # pylint: disable=g-import-not-at-top - cloned_model = models.clone_model(model, input_tensors=None) + cloned_model = models.clone_model(model, input_tensors=inputs) # Compile and build model. if isinstance(model.optimizer, optimizers.TFOptimizer): @@ -387,22 +525,29 @@ def _clone_and_build_model(model): optimizer_config = model.optimizer.get_config() optimizer = model.optimizer.__class__.from_config(optimizer_config) + # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a + # single tensor should be OK but it throws an error in that case. + if (targets is not None and not isinstance(targets, list) and + not isinstance(targets, dict)): + targets = [targets] cloned_model.compile( optimizer, model.loss, metrics=model.metrics, loss_weights=model.loss_weights, sample_weight_mode=model.sample_weight_mode, - weighted_metrics=model.weighted_metrics) + weighted_metrics=model.weighted_metrics, + target_tensors=targets) return cloned_model -def clone_model_on_towers(model, strategy, make_callback_model=False): +def clone_model_on_towers( + model, strategy, make_callback_model=False, inputs=None, targets=None): """Create a cloned model on each tower, unless already created.""" if not model._grouped_model: with strategy.scope(): model._grouped_model = strategy.call_for_each_tower( - _clone_and_build_model, model) + _clone_and_build_model, model, inputs, targets) if make_callback_model: model._make_callback_model() -- GitLab From 7e2577b0984a1d8f41af97942fcdf5b9f1ff8622 Mon Sep 17 00:00:00 2001 From: Jiri Simsa Date: Wed, 5 Sep 2018 15:07:07 -0700 Subject: [PATCH 145/540] [tf.data] Minor fix to remove unnecessary difference between the implementations of the batch and padded batch reducers. PiperOrigin-RevId: 211706766 --- tensorflow/contrib/data/python/ops/batching.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 9c2001c34f..367c159dc5 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -272,9 +272,9 @@ def _padded_batch_dense_window(dataset, padded_shape, padding_value=None): padding_value = 0 def batch_init_fn(_): - return array_ops.fill( - array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0), - constant_op.constant(padding_value, dtype=dataset.output_types)) + batch_shape = array_ops.concat( + [np.array([0], dtype=np.int32), padded_shape], 0) + return gen_array_ops.empty(batch_shape, dtype=dataset.output_types) def batch_reduce_fn(state, value): return array_ops.concat([state, [value]], 0) -- GitLab From 59c43f26dec90afa66116acdaff8cdf693728adb Mon Sep 17 00:00:00 2001 From: Jianwei Xie Date: Wed, 5 Sep 2018 15:08:37 -0700 Subject: [PATCH 146/540] Remove logging which generates tons of logs for large model. PiperOrigin-RevId: 211707155 --- tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index a423aeace7..170977d8ab 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging @contextlib.contextmanager @@ -258,7 +257,6 @@ def replicated_scope(num_replicas): collections = [ops.GraphKeys.GLOBAL_VARIABLES] kwargs["collections"] = [] - logging.info("Constructing replicated variable %s", name) variables = [] index = {} for i in range(num_replicas): -- GitLab From 24bd1154b3c83cbf07883010240c3d1d13e25833 Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Wed, 5 Sep 2018 15:28:00 -0700 Subject: [PATCH 147/540] Addressing review comments --- .../core/common_runtime/mkl_cpu_allocator.h | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 2778213a82..553f07020e 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -50,9 +50,9 @@ class MklSubAllocator : public SubAllocator { void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); } }; -/// CPU allocator that handles small-size allocations by calling -/// suballocator directly. Mostly, it is just a wrapper around a suballocator -/// (that calls malloc and free directly) with support for bookkeeping. +// CPU allocator that handles small-size allocations by calling +// suballocator directly. Mostly, it is just a wrapper around a suballocator +// (that calls malloc and free directly) with support for bookkeeping. class MklSmallSizeAllocator : public VisitableAllocator { public: MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, @@ -67,12 +67,12 @@ class MklSmallSizeAllocator : public VisitableAllocator { inline string Name() override { return name_; } void* AllocateRaw(size_t alignment, size_t num_bytes) override { - void* ptr = nullptr; - if ((ptr = sub_allocator_->Alloc(alignment, num_bytes)) != nullptr) { + void* ptr = sub_allocator_->Alloc(alignment, num_bytes); + if (ptr != nullptr) { std::pair map_val(ptr, num_bytes); mutex_lock l(mutex_); // Check that insertion in the hash map was successful. - CHECK_EQ(map_.insert(map_val).second, true); + CHECK(map_.insert(map_val).second); // Increment statistics for small-size allocations. IncrementStats(num_bytes); // Call alloc visitors. @@ -100,6 +100,9 @@ class MklSmallSizeAllocator : public VisitableAllocator { sub_allocator_->Free(ptr, dealloc_bytes); DecrementStats(dealloc_bytes); map_.erase(map_iter); + } else { + LOG(ERROR) << "tried to deallocate invalid pointer"; + return; } } @@ -129,8 +132,8 @@ class MklSmallSizeAllocator : public VisitableAllocator { } private: - /// Increment statistics for the allocator handling small allocations. - inline void IncrementStats(size_t alloc_size) { + // Increment statistics for the allocator handling small allocations. + inline void IncrementStats(size_t alloc_size) GUARDED_BY(mutex_) { ++stats_.num_allocs; stats_.bytes_in_use += alloc_size; stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, @@ -139,27 +142,27 @@ class MklSmallSizeAllocator : public VisitableAllocator { static_cast(stats_.max_alloc_size)); } - /// Decrement statistics for the allocator handling small allocations. - inline void DecrementStats(size_t dealloc_size) { + // Decrement statistics for the allocator handling small allocations. + inline void DecrementStats(size_t dealloc_size) GUARDED_BY(mutex_) { stats_.bytes_in_use -= dealloc_size; } SubAllocator* sub_allocator_; // Not owned by this class. - /// Mutex for protecting updates to map of allocations. + // Mutex for protecting updates to map of allocations. mutable mutex mutex_; - /// Allocator name + // Allocator name string name_; - /// Hash map to keep track of "small" allocations - /// We do not use BFC allocator for small allocations. + // Hash map to keep track of "small" allocations + // We do not use BFC allocator for small allocations. std::unordered_map map_ GUARDED_BY(mutex_); - /// Allocator stats for small allocs + // Allocator stats for small allocs AllocatorStats stats_ GUARDED_BY(mutex_); - /// Visitors + // Visitors std::vector alloc_visitors_ GUARDED_BY(mutex_); std::vector free_visitors_ GUARDED_BY(mutex_); }; @@ -217,6 +220,9 @@ class MklCPUAllocator : public VisitableAllocator { VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes; sub_allocator_ = new MklSubAllocator(); + + // SubAllocator is owned by BFCAllocator, so we do not need to deallocate + // it in MklSmallSizeAllocator. small_size_allocator_ = new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); large_size_allocator_ = new BFCAllocator(sub_allocator_, max_mem_bytes, @@ -264,8 +270,11 @@ class MklCPUAllocator : public VisitableAllocator { stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use; stats->max_bytes_in_use = l_stats.max_bytes_in_use + s_stats.max_bytes_in_use; - stats->max_alloc_size = std::max(l_stats.max_alloc_size, - s_stats.max_alloc_size); + + // Since small-size allocations go to MklSmallSizeAllocator, + // max_alloc_size from large_size_allocator would be the maximum + // size allocated by MklCPUAllocator. + stats->max_alloc_size = l_stats.max_alloc_size; } void ClearStats() override { @@ -308,13 +317,13 @@ class MklCPUAllocator : public VisitableAllocator { TF_CHECK_OK(s); // way to assert with an error message } - /// Do we allow growth in BFC Allocator + // Do we allow growth in BFC Allocator static const bool kAllowGrowth = true; - /// Name + // Name static constexpr const char* kName = "mklcpu"; - /// The alignment that we need for the allocations + // The alignment that we need for the allocations static constexpr const size_t kAlignment = 64; VisitableAllocator* large_size_allocator_; // owned by this class @@ -322,8 +331,8 @@ class MklCPUAllocator : public VisitableAllocator { SubAllocator* sub_allocator_; // not owned by this class - /// Size in bytes that defines the upper-bound for "small" allocations. - /// Any allocation below this threshold is "small" allocation. + // Size in bytes that defines the upper-bound for "small" allocations. + // Any allocation below this threshold is "small" allocation. static constexpr const size_t kSmallAllocationsThreshold = 4096; }; -- GitLab From 99fe2f603466a03897fd653f9fdf583b78b9d5b0 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 5 Sep 2018 15:13:16 -0700 Subject: [PATCH 148/540] Fold CapturingGraph into FuncGraph. There's no need for the two separate classes anymore. This also cleans up some other parts of the interface: * Removes the clear_resource_control_flow_state, which isn't used anywhere * Makes capture_value a private method of FuncGraph (_capture_helper) * Makes create_substitute_placeholder private PiperOrigin-RevId: 211707906 --- tensorflow/python/eager/function.py | 211 ++++++++++--------- tensorflow/python/keras/engine/base_layer.py | 2 +- tensorflow/python/keras/engine/network.py | 2 +- 3 files changed, 111 insertions(+), 104 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index b57979b484..d56c1457e0 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -59,7 +59,7 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access -def create_substitute_placeholder(value, name, dtype=None): +def _create_substitute_placeholder(value, name, dtype=None): """Creates a placeholder for `value` and propagates shape info to it.""" # Note: setting ops.control_dependencies(None) ensures we always put # capturing placeholders outside of any control flow context. @@ -91,100 +91,6 @@ def create_substitute_placeholder(value, name, dtype=None): return placeholder -def capture_value(tensor_map, value, dtype, name): - """Capture a value from outside the function, to pass in as an extra arg.""" - captured_value = tensor_map.get(value, None) - if captured_value is None: - captured_value = create_substitute_placeholder(value, name=name, - dtype=dtype) - tensor_map[value] = captured_value - tape.record_operation("captured_value", [captured_value], [value], - lambda x: [x]) - return captured_value - - -class CapturingGraph(ops.Graph): - """Graph that can capture tensors from other graphs. - - Attributes: - captures: Maps external tensor -> internal tensor (e.g. input placeholder). - The entries are in the order they were captured. - """ - - def __init__(self): - super(CapturingGraph, self).__init__() - - self.captures = collections.OrderedDict() - self._building_function = True - - # Map from resource tensor name to last op (in program order) which uses - # this tensor. Used to enforce that execution order matches program order - # for resource tensors. - self._last_op_using_resource_tensor = {} - - def clear_resource_control_flow_state(self): - self._last_op_using_resource_tensor = {} - - # TODO(skyewm): get rid of name and use the name of `tensor`. - def capture(self, tensor, name=None): - """Capture `tensor` if it's external to this graph. - - If `tensor` is from a different graph, returns a placeholder for it. - `tensor` and the placeholder will also appears in self.captures. Multiple - calls to this method with the same `tensor` argument will return the same - placeholder. If `tensor` is from this graph, returns `tensor`. - - Args: - tensor: Tensor. May be from this FuncGraph or a different graph. - name: Optional name if a placeholder is created. - - Returns: - Tensor from this FuncGraph. - """ - if isinstance(tensor, ops.EagerTensor): - if name is None: - name = str(ops.uid()) - return capture_value(self.captures, tensor, tensor.dtype, name) - if tensor.graph is not self: - if name is None: - name = tensor.op.name - return capture_value(self.captures, tensor, tensor.dtype, name) - return tensor - - def create_op( - self, - op_type, - inputs, - dtypes, # pylint: disable=redefined-outer-name - input_types=None, - name=None, - attrs=None, - op_def=None, - compute_shapes=True, - compute_device=True): - """Captures an external inputs before calling Graph.capture_op.""" - # This capturing logic interacts poorly with control flow contexts which - # want to replace inputs of ops far too late in the process. This can lead - # the context to get confused and try to create an Enter for an Enter. We - # can detect this here and skip the additional Enter which can confuse loop - # validation logic. - if op_type == "Enter" and inputs[0].op.type == "Enter": - if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: - return inputs[0].op - # Calling AddValue on the control flow contexts to force creation of the - # backward accumulators in the original graph before we create placeholders - # to capture the inputs. - ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access - for i, inp in enumerate(inputs): - if ctxt is not None and hasattr(ctxt, "AddValue"): - inp = ctxt.AddValue(inp) - inp = self.capture(inp) - inputs[i] = inp - return super(CapturingGraph, self).create_op( - op_type, inputs, dtypes, input_types, name, attrs, op_def, - compute_device=compute_device) - - def _get_device_functions(ctx, graph): """Returns a tuple of device functions representing the device stack.""" if ctx.executing_eagerly(): @@ -193,7 +99,7 @@ def _get_device_functions(ctx, graph): return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access -class FuncGraph(CapturingGraph): +class FuncGraph(ops.Graph): """Graph representing a function body. Attributes: @@ -210,6 +116,8 @@ class FuncGraph(CapturingGraph): variables: Variables that should be watched during function execution. outer_graph: The graph this function is defined in. May be another FuncGraph or the global default Graph. + captures: Maps external tensor -> internal tensor (i.e. input placeholder). + The entries are in the order they were captured. seed: The graph-level random seed. """ @@ -230,6 +138,13 @@ class FuncGraph(CapturingGraph): self.structured_outputs = None self.variables = [] self.outer_graph = ops.get_default_graph() + self.captures = collections.OrderedDict() + + self._building_function = True + # Map from resource tensor name to last op (in program order) which uses + # this tensor. Used to enforce that execution order matches program order + # for resource tensors. + self._last_op_using_resource_tensor = {} graph = self.outer_graph @@ -258,15 +173,107 @@ class FuncGraph(CapturingGraph): self._graph_key = graph._graph_key # pylint: enable=protected-access + def create_op( + self, + op_type, + inputs, + dtypes, + input_types=None, + name=None, + attrs=None, + op_def=None, + compute_shapes=True, + compute_device=True): + """Like Graph.create_op, except handles external input tensors. + + This overload adds functionality to create_op to "capture" any external + input tensors, i.e. tensors from the eager context or outer function graphs + if this is a nested function. See `capture` for more information. + + Args: + op_type: The `Operation` type to create. This corresponds to the + `OpDef.name` field for the proto that defines the operation. + inputs: A list of `Tensor` objects that will be inputs to the `Operation`. + dtypes: A list of `DType` objects that will be the types of the tensors + that the operation produces. + input_types: (Optional.) A list of `DType`s that will be the types of + the tensors that the operation consumes. By default, uses the base + `DType` of each input in `inputs`. Operations that expect + reference-typed inputs must specify `input_types` explicitly. + name: (Optional.) A string name for the operation. If not specified, a + name is generated based on `op_type`. + attrs: (Optional.) A dictionary where the key is the attribute name (a + string) and the value is the respective `attr` attribute of the + `NodeDef` proto that will represent the operation (an `AttrValue` + proto). + op_def: (Optional.) The `OpDef` proto that describes the `op_type` that + the operation will have. + compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always + computed). + compute_device: (Optional.) If True, device functions will be executed + to compute the device property of the Operation. + + Returns: + An `Operation` object. + """ + # This capturing logic interacts poorly with control flow contexts which + # want to replace inputs of ops far too late in the process. This can lead + # the context to get confused and try to create an Enter for an Enter. We + # can detect this here and skip the additional Enter which can confuse loop + # validation logic. + if op_type == "Enter" and inputs[0].op.type == "Enter": + if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s: + return inputs[0].op + # Calling AddValue on the control flow contexts to force creation of the + # backward accumulators in the original graph before we create placeholders + # to capture the inputs. + ctxt = ops.get_default_graph()._control_flow_context # pylint: disable=protected-access + for i, inp in enumerate(inputs): + # TPU Estimator defines a control flow context with no AddValue method. + if ctxt is not None and hasattr(ctxt, "AddValue"): + inp = ctxt.AddValue(inp) + inp = self.capture(inp) + inputs[i] = inp + return super(FuncGraph, self).create_op( + op_type, inputs, dtypes, input_types, name, attrs, op_def, + compute_device=compute_device) + def capture(self, tensor, name=None): - """Calls CapturingGraph.capture and updates self.inputs if necessary.""" - new_capture = tensor not in self.captures - internal_tensor = super(FuncGraph, self).capture(tensor, name) + """Captures `tensor` if it's external to this graph. - if new_capture and tensor is not internal_tensor: - self.inputs.append(internal_tensor) + If `tensor` is from a different graph, returns a placeholder for it. + `tensor` and the placeholder will appear in self.captures, and the + placeholder will appear in self.inputs. Multiple calls to this method with + the same `tensor` argument will return the same placeholder. If `tensor` is + from this graph, returns `tensor`. + + Args: + tensor: Tensor. May be from this FuncGraph or a different graph. + name: Optional name if a placeholder is created. + + Returns: + Tensor from this FuncGraph. + """ + if isinstance(tensor, ops.EagerTensor): + if name is None: + name = str(ops.uid()) + return self._capture_helper(tensor, name) + if tensor.graph is not self: + if name is None: + name = tensor.op.name + return self._capture_helper(tensor, name) + return tensor - return internal_tensor + def _capture_helper(self, tensor, name): + captured_tensor = self.captures.get(tensor, None) + if captured_tensor is None: + captured_tensor = _create_substitute_placeholder(tensor, name=name, + dtype=tensor.dtype) + self.captures[tensor] = captured_tensor + self.inputs.append(captured_tensor) + tape.record_operation("captured_value", [captured_tensor], [tensor], + lambda x: [x]) + return captured_tensor @property def external_captures(self): diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index b6b05c0311..cb19a412a2 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -1001,7 +1001,7 @@ class Layer(checkpointable.CheckpointableBase): self.build(input_shape) with context.graph_mode(): - graph = eager_function.CapturingGraph() + graph = eager_function.FuncGraph('graph') with graph.as_default(): if isinstance(input_shape, list): inputs = [generate_placeholders_from_shape(shape) diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index f8c23ed124..10dd70cf23 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -770,7 +770,7 @@ class Network(base_layer.Layer): # and graph building, the variables created after building the model in # a Graph are still valid when executing eagerly. with context.graph_mode(): - graph = eager_function.CapturingGraph() + graph = eager_function.FuncGraph('graph') with graph.as_default(): if isinstance(input_shape, list): x = [base_layer.generate_placeholders_from_shape(shape) -- GitLab From ebf6d259fd4c57114c17646e40fdcfa4a1472972 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 5 Sep 2018 15:15:20 -0700 Subject: [PATCH 149/540] Deprecate `tf.ReaderBase` and related APIs. These APIs are based on queue runners, which have been deprecated and will be removed in TensorFlow 2.0. They have been replaced with `tf.data.Dataset`, which provides a more efficient version of the same functionality. PiperOrigin-RevId: 211708268 --- tensorflow/python/ops/io_ops.py | 37 +++++++++++---- ...nsorflow.-fixed-length-record-reader.pbtxt | 46 ------------------- .../v2/tensorflow.-identity-reader.pbtxt | 46 ------------------- .../v2/tensorflow.-l-m-d-b-reader.pbtxt | 46 ------------------- .../golden/v2/tensorflow.-reader-base.pbtxt | 45 ------------------ .../v2/tensorflow.-t-f-record-reader.pbtxt | 46 ------------------- .../v2/tensorflow.-text-line-reader.pbtxt | 46 ------------------- .../v2/tensorflow.-whole-file-reader.pbtxt | 46 ------------------- .../tools/api/golden/v2/tensorflow.pbtxt | 28 ----------- 9 files changed, 29 insertions(+), 357 deletions(-) delete mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt delete mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt delete mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt delete mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt delete mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt delete mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt delete mode 100644 tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index fbc1350c61..f84785df2c 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -33,8 +33,9 @@ from tensorflow.python.ops import gen_io_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_io_ops import * -from tensorflow.python.util.tf_export import tf_export # pylint: enable=wildcard-import +from tensorflow.python.util import deprecation +from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access @@ -95,7 +96,7 @@ def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type, preferred_shard, name=name) -@tf_export("ReaderBase") +@tf_export(v1=["ReaderBase"]) class ReaderBase(object): """Base class for different Reader types, that produce a record every step. @@ -309,7 +310,7 @@ ops.NotDifferentiable("ReaderRestoreState") ops.NotDifferentiable("ReaderReset") -@tf_export("WholeFileReader") +@tf_export(v1=["WholeFileReader"]) class WholeFileReader(ReaderBase): """A Reader that outputs the entire contents of a file as a value. @@ -324,6 +325,9 @@ class WholeFileReader(ReaderBase): @end_compatibility """ + @deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.map(tf.read_file)`.") def __init__(self, name=None): """Create a WholeFileReader. @@ -337,7 +341,7 @@ class WholeFileReader(ReaderBase): ops.NotDifferentiable("WholeFileReader") -@tf_export("TextLineReader") +@tf_export(v1=["TextLineReader"]) class TextLineReader(ReaderBase): """A Reader that outputs the lines of a file delimited by newlines. @@ -351,6 +355,9 @@ class TextLineReader(ReaderBase): """ # TODO(josh11b): Support serializing and restoring state. + @deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.TextLineDataset`.") def __init__(self, skip_header_lines=None, name=None): """Create a TextLineReader. @@ -367,7 +374,7 @@ class TextLineReader(ReaderBase): ops.NotDifferentiable("TextLineReader") -@tf_export("FixedLengthRecordReader") +@tf_export(v1=["FixedLengthRecordReader"]) class FixedLengthRecordReader(ReaderBase): """A Reader that outputs fixed-length records from a file. @@ -380,6 +387,9 @@ class FixedLengthRecordReader(ReaderBase): """ # TODO(josh11b): Support serializing and restoring state. + @deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.FixedLengthRecordDataset`.") def __init__(self, record_bytes, header_bytes=None, @@ -410,7 +420,7 @@ class FixedLengthRecordReader(ReaderBase): ops.NotDifferentiable("FixedLengthRecordReader") -@tf_export("TFRecordReader") +@tf_export(v1=["TFRecordReader"]) class TFRecordReader(ReaderBase): """A Reader that outputs the records from a TFRecords file. @@ -423,6 +433,9 @@ class TFRecordReader(ReaderBase): """ # TODO(josh11b): Support serializing and restoring state. + @deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.TFRecordDataset`.") def __init__(self, name=None, options=None): """Create a TFRecordReader. @@ -441,7 +454,7 @@ class TFRecordReader(ReaderBase): ops.NotDifferentiable("TFRecordReader") -@tf_export("LMDBReader") +@tf_export(v1=["LMDBReader"]) class LMDBReader(ReaderBase): """A Reader that outputs the records from a LMDB file. @@ -452,6 +465,10 @@ class LMDBReader(ReaderBase): use `tf.data` to get data into your model. @end_compatibility """ + + @deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.contrib.data.LMDBDataset`.") def __init__(self, name=None, options=None): """Create a LMDBReader. @@ -459,6 +476,7 @@ class LMDBReader(ReaderBase): name: A name for the operation (optional). options: A LMDBRecordOptions object (optional). """ + del options rr = gen_io_ops.lmdb_reader(name=name) super(LMDBReader, self).__init__(rr) @@ -466,7 +484,7 @@ class LMDBReader(ReaderBase): ops.NotDifferentiable("LMDBReader") -@tf_export("IdentityReader") +@tf_export(v1=["IdentityReader"]) class IdentityReader(ReaderBase): """A Reader that outputs the queued work as both the key and value. @@ -481,6 +499,9 @@ class IdentityReader(ReaderBase): @end_compatibility """ + @deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.map(...)`.") def __init__(self, name=None): """Create a IdentityReader. diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt deleted file mode 100644 index 260c796fd6..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.-fixed-length-record-reader.pbtxt +++ /dev/null @@ -1,46 +0,0 @@ -path: "tensorflow.FixedLengthRecordReader" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "reader_ref" - mtype: "" - } - member { - name: "supports_serialize" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\', \'encoding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " - } - member_method { - name: "num_records_produced" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_work_units_completed" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read" - argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read_up_to" - argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reset" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "restore_state" - argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "serialize_state" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt deleted file mode 100644 index 2eda320d63..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.-identity-reader.pbtxt +++ /dev/null @@ -1,46 +0,0 @@ -path: "tensorflow.IdentityReader" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "reader_ref" - mtype: "" - } - member { - name: "supports_serialize" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_records_produced" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_work_units_completed" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read" - argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read_up_to" - argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reset" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "restore_state" - argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "serialize_state" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt deleted file mode 100644 index f9b7e9bbca..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.-l-m-d-b-reader.pbtxt +++ /dev/null @@ -1,46 +0,0 @@ -path: "tensorflow.LMDBReader" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "reader_ref" - mtype: "" - } - member { - name: "supports_serialize" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "num_records_produced" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_work_units_completed" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read" - argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read_up_to" - argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reset" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "restore_state" - argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "serialize_state" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt deleted file mode 100644 index f6a3ce76a1..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.-reader-base.pbtxt +++ /dev/null @@ -1,45 +0,0 @@ -path: "tensorflow.ReaderBase" -tf_class { - is_instance: "" - is_instance: "" - member { - name: "reader_ref" - mtype: "" - } - member { - name: "supports_serialize" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'reader_ref\', \'supports_serialize\'], varargs=None, keywords=None, defaults=[\'False\'], " - } - member_method { - name: "num_records_produced" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_work_units_completed" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read" - argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read_up_to" - argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reset" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "restore_state" - argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "serialize_state" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt deleted file mode 100644 index cdf7937391..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.-t-f-record-reader.pbtxt +++ /dev/null @@ -1,46 +0,0 @@ -path: "tensorflow.TFRecordReader" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "reader_ref" - mtype: "" - } - member { - name: "supports_serialize" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "num_records_produced" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_work_units_completed" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read" - argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read_up_to" - argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reset" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "restore_state" - argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "serialize_state" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt deleted file mode 100644 index e9779f0762..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.-text-line-reader.pbtxt +++ /dev/null @@ -1,46 +0,0 @@ -path: "tensorflow.TextLineReader" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "reader_ref" - mtype: "" - } - member { - name: "supports_serialize" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'skip_header_lines\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "num_records_produced" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_work_units_completed" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read" - argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read_up_to" - argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reset" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "restore_state" - argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "serialize_state" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt deleted file mode 100644 index 4ac759891c..0000000000 --- a/tensorflow/tools/api/golden/v2/tensorflow.-whole-file-reader.pbtxt +++ /dev/null @@ -1,46 +0,0 @@ -path: "tensorflow.WholeFileReader" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "reader_ref" - mtype: "" - } - member { - name: "supports_serialize" - mtype: "" - } - member_method { - name: "__init__" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_records_produced" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "num_work_units_completed" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read" - argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "read_up_to" - argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reset" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "restore_state" - argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "serialize_state" - argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 7d45ea22c8..9332e16bf6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -60,10 +60,6 @@ tf_module { name: "FixedLenSequenceFeature" mtype: "" } - member { - name: "FixedLengthRecordReader" - mtype: "" - } member { name: "GIT_VERSION" mtype: "" @@ -108,10 +104,6 @@ tf_module { name: "HistogramProto" mtype: "" } - member { - name: "IdentityReader" - mtype: "" - } member { name: "IndexedSlices" mtype: "" @@ -120,10 +112,6 @@ tf_module { name: "InteractiveSession" mtype: "" } - member { - name: "LMDBReader" - mtype: "" - } member { name: "LogMessage" mtype: "" @@ -176,10 +164,6 @@ tf_module { name: "RandomShuffleQueue" mtype: "" } - member { - name: "ReaderBase" - mtype: "" - } member { name: "RegisterGradient" mtype: "" @@ -224,10 +208,6 @@ tf_module { name: "SummaryMetadata" mtype: "" } - member { - name: "TFRecordReader" - mtype: "" - } member { name: "Tensor" mtype: "" @@ -244,10 +224,6 @@ tf_module { name: "TensorShape" mtype: "" } - member { - name: "TextLineReader" - mtype: "" - } member { name: "VERSION" mtype: "" @@ -272,10 +248,6 @@ tf_module { name: "VariableSynchronization" mtype: "" } - member { - name: "WholeFileReader" - mtype: "" - } member { name: "app" mtype: "" -- GitLab From 47b1af2a3a724a5d783ae06ca0e0e78b30e0799b Mon Sep 17 00:00:00 2001 From: Eddie Zhou Date: Wed, 5 Sep 2018 15:24:38 -0700 Subject: [PATCH 150/540] Expose an axis argument for VocabInfo, which allows for warm-starting of the second axis of Tensors through tf.train.warm_start. Note that the underlying initializer already has this functionality (for example, for output layers). PiperOrigin-RevId: 211709879 --- tensorflow/python/estimator/estimator.py | 2 +- tensorflow/python/training/checkpoint_ops.py | 3 +- .../python/training/warm_starting_util.py | 100 +++++++++++-- .../training/warm_starting_util_test.py | 140 ++++++++++++++++-- .../v1/tensorflow.estimator.-vocab-info.pbtxt | 4 + .../v1/tensorflow.train.-vocab-info.pbtxt | 4 + .../v2/tensorflow.estimator.-vocab-info.pbtxt | 4 + .../v2/tensorflow.train.-vocab-info.pbtxt | 4 + 8 files changed, 235 insertions(+), 26 deletions(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index e44a69b374..0f20acefdf 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -2056,7 +2056,7 @@ class WarmStartSettings( var_name_to_vocab_info: [Optional] Dict of variable names (strings) to `tf.estimator.VocabInfo`. The variable names should be "full" variables, not the names of the partitions. If not explicitly provided, the variable - is assumed to have no vocabulary. + is assumed to have no (changes to) vocabulary. var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to name of the previously-trained variable in `ckpt_to_initialize_from`. If not explicitly provided, the name of the variable is assumed to be same diff --git a/tensorflow/python/training/checkpoint_ops.py b/tensorflow/python/training/checkpoint_ops.py index a6e9662b73..cfd9b39ddc 100644 --- a/tensorflow/python/training/checkpoint_ops.py +++ b/tensorflow/python/training/checkpoint_ops.py @@ -268,7 +268,8 @@ def _load_and_remap_matrix_initializer(ckpt_path, vocab files are the same, and no column remapping is done. The returned initializer only supports div-partitioning along the row axis. It - does not support partitioning along the column axis or mod-partitioning. + does not support partitioning along the column axis (as this is not common in + practice) or mod-partitioning. NOTE: When this is used to warm-start variables, client code should use `tf.lookup.index_table_from_tensor()` like diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index c0dd46bfa5..bea9bb6dff 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -41,6 +41,7 @@ class VocabInfo( "old_vocab", "old_vocab_size", "backup_initializer", + "axis", ])): """Vocabulary information for warm-starting. @@ -62,6 +63,42 @@ class VocabInfo( backup_initializer: [Optional] A variable initializer used for variables corresponding to new vocabulary entries and OOV. If not provided, these entries will be zero-initialized. + axis: [Optional] Denotes what axis the vocabulary corresponds to. The + default, 0, corresponds to the most common use case (embeddings or + linear weights for binary classification / regression). An axis of 1 + could be used for warm-starting output layers with class vocabularies. + + For example: + + embeddings_vocab_info = tf.VocabInfo( + new_vocab='embeddings_vocab', + new_vocab_size=100, + num_oov_buckets=1, + old_vocab='pretrained_embeddings_vocab', + old_vocab_size=10000, + backup_initializer=tf.truncated_normal_initializer( + mean=0.0, stddev=(1 / math.sqrt(embedding_dim))), + axis=0) + + softmax_output_layer_kernel_vocab_info = tf.VocabInfo( + new_vocab='class_vocab', + new_vocab_size=5, + num_oov_buckets=0, # No OOV for classes. + old_vocab='old_class_vocab', + old_vocab_size=8, + backup_initializer=tf.glorot_uniform_initializer(), + axis=1) + + softmax_output_layer_bias_vocab_info = tf.VocabInfo( + new_vocab='class_vocab', + new_vocab_size=5, + num_oov_buckets=0, # No OOV for classes. + old_vocab='old_class_vocab', + old_vocab_size=8, + backup_initializer=tf.zeros_initializer(), + axis=0) + + Currently, only axis=0 and axis=1 are supported. """ def __new__(cls, @@ -70,7 +107,12 @@ class VocabInfo( num_oov_buckets, old_vocab, old_vocab_size=-1, - backup_initializer=None): + backup_initializer=None, + axis=0): + if axis != 0 and axis != 1: + raise ValueError("The only supported values for the axis argument are 0 " + "and 1. Provided axis: {}".format(axis)) + return super(VocabInfo, cls).__new__( cls, new_vocab, @@ -79,6 +121,7 @@ class VocabInfo( old_vocab, old_vocab_size, backup_initializer, + axis, ) @@ -149,7 +192,8 @@ def _warm_start_var_with_vocab(var, previous_vocab_size=-1, current_oov_buckets=0, prev_tensor_name=None, - initializer=None): + initializer=None, + axis=0): """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`. Use this method when the `var` is backed by vocabulary. This method stitches @@ -180,6 +224,7 @@ def _warm_start_var_with_vocab(var, None, we lookup tensor with same name as given `var`. initializer: Variable initializer to be used for missing entries. If None, missing entries will be zero-initialized. + axis: Axis of the variable that the provided vocabulary corresponds to. Raises: ValueError: If required args are not provided. @@ -204,6 +249,8 @@ def _warm_start_var_with_vocab(var, # Assume tensor name remains the same. prev_tensor_name = _infer_var_name(var) + # TODO(eddz): Fix functionality for rank-1 Variables (like FC biases). + total_v_first_axis = sum([v.get_shape().as_list()[0] for v in var]) for v in var: v_shape = v.get_shape().as_list() slice_info = v._get_save_slice_info() @@ -213,19 +260,45 @@ def _warm_start_var_with_vocab(var, full_shape=slice_info.full_shape, var_offset=slice_info.var_offset) - # TODO(eddz): Support cases where class vocabularies need remapping too. + if axis == 0: + new_row_vocab_size = current_vocab_size + new_col_vocab_size = v_shape[1] + old_row_vocab_size = previous_vocab_size + old_row_vocab_file = prev_vocab_path + new_row_vocab_file = current_vocab_path + old_col_vocab_file = None + new_col_vocab_file = None + num_row_oov_buckets = current_oov_buckets + num_col_oov_buckets = 0 + elif axis == 1: + # Note that we must compute this value across all partitions, whereas + # in the axis = 0 case, we can simply use v_shape[1] because we don't + # allow partitioning across axis = 1. + new_row_vocab_size = total_v_first_axis + new_col_vocab_size = current_vocab_size + old_row_vocab_size = -1 + old_row_vocab_file = None + new_row_vocab_file = None + old_col_vocab_file = prev_vocab_path + new_col_vocab_file = current_vocab_path + num_row_oov_buckets = 0 + num_col_oov_buckets = current_oov_buckets + else: + raise ValueError("The only supported values for the axis argument are 0 " + "and 1. Provided axis: {}".format(axis)) + init = checkpoint_ops._load_and_remap_matrix_initializer( ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt), old_tensor_name=prev_tensor_name, - new_row_vocab_size=current_vocab_size, - new_col_vocab_size=v_shape[1], - old_row_vocab_size=previous_vocab_size, - old_row_vocab_file=prev_vocab_path, - new_row_vocab_file=current_vocab_path, - old_col_vocab_file=None, - new_col_vocab_file=None, - num_row_oov_buckets=current_oov_buckets, - num_col_oov_buckets=0, + new_row_vocab_size=new_row_vocab_size, + new_col_vocab_size=new_col_vocab_size, + old_row_vocab_size=old_row_vocab_size, + old_row_vocab_file=old_row_vocab_file, + new_row_vocab_file=new_row_vocab_file, + old_col_vocab_file=old_col_vocab_file, + new_col_vocab_file=new_col_vocab_file, + num_row_oov_buckets=num_row_oov_buckets, + num_col_oov_buckets=num_col_oov_buckets, initializer=initializer) new_init_val = ops.convert_to_tensor( init(shape=v_shape, partition_info=partition_info)) @@ -374,7 +447,8 @@ def warm_start(ckpt_to_initialize_from, previous_vocab_size=vocab_info.old_vocab_size, current_oov_buckets=vocab_info.num_oov_buckets, prev_tensor_name=prev_var_name, - initializer=vocab_info.backup_initializer) + initializer=vocab_info.backup_initializer, + axis=vocab_info.axis) else: # For the special value of vars_to_warm_start = None, # we only warm-start variables with explicitly specified vocabularies. diff --git a/tensorflow/python/training/warm_starting_util_test.py b/tensorflow/python/training/warm_starting_util_test.py index 70a84bc3f6..3ee0f6aaa2 100644 --- a/tensorflow/python/training/warm_starting_util_test.py +++ b/tensorflow/python/training/warm_starting_util_test.py @@ -107,7 +107,7 @@ class WarmStartingUtilTest(test.TestCase): "fruit_weights", initializer=[[0.], [0.], [0.], [0.]]) ws_util._warm_start_var(fruit_weights, self.get_temp_dir()) sess.run(variables.global_variables_initializer()) - self.assertAllEqual(prev_val, fruit_weights.eval(sess)) + self.assertAllClose(prev_val, fruit_weights.eval(sess)) def testWarmStartVarPrevVarPartitioned(self): _, weights = self._create_prev_run_var( @@ -123,7 +123,7 @@ class WarmStartingUtilTest(test.TestCase): "fruit_weights", initializer=[[0.], [0.], [0.], [0.]]) ws_util._warm_start_var(fruit_weights, self.get_temp_dir()) sess.run(variables.global_variables_initializer()) - self.assertAllEqual(prev_val, fruit_weights.eval(sess)) + self.assertAllClose(prev_val, fruit_weights.eval(sess)) def testWarmStartVarCurrentVarPartitioned(self): _, prev_val = self._create_prev_run_var( @@ -143,7 +143,7 @@ class WarmStartingUtilTest(test.TestCase): fruit_weights = fruit_weights._get_variable_list() new_val = np.concatenate( [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0) - self.assertAllEqual(prev_val, new_val) + self.assertAllClose(prev_val, new_val) def testWarmStartVarBothVarsPartitioned(self): _, weights = self._create_prev_run_var( @@ -170,7 +170,7 @@ class WarmStartingUtilTest(test.TestCase): fruit_weights = fruit_weights._get_variable_list() new_val = np.concatenate( [fruit_weights[0].eval(sess), fruit_weights[1].eval(sess)], axis=0) - self.assertAllEqual(prev_val, new_val) + self.assertAllClose(prev_val, new_val) def testWarmStartVarWithVocab(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], @@ -189,9 +189,34 @@ class WarmStartingUtilTest(test.TestCase): ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5, self.get_temp_dir(), prev_vocab_path) sess.run(variables.global_variables_initializer()) - self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]], + self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]], fruit_weights.eval(sess)) + def testWarmStartVarWithColumnVocab(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.], + [2.3, 2., 0.]], fruit_output_layer.eval(sess)) + def testWarmStartVarWithVocabConstrainedOldVocabSize(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -215,7 +240,7 @@ class WarmStartingUtilTest(test.TestCase): previous_vocab_size=2) sess.run(variables.global_variables_initializer()) # Old vocabulary limited to ['apple', 'banana']. - self.assertAllEqual([[0.], [0.], [1.], [0.5], [0.]], + self.assertAllClose([[0.], [0.], [1.], [0.5], [0.]], fruit_weights.eval(sess)) def testWarmStartVarWithVocabPrevVarPartitioned(self): @@ -238,9 +263,36 @@ class WarmStartingUtilTest(test.TestCase): ws_util._warm_start_var_with_vocab(fruit_weights, new_vocab_path, 5, self.get_temp_dir(), prev_vocab_path) sess.run(variables.global_variables_initializer()) - self.assertAllEqual([[2.], [1.5], [1.], [0.5], [0.]], + self.assertAllClose([[2.], [1.5], [1.], [0.5], [0.]], fruit_weights.eval(sess)) + def testWarmStartVarWithColumnVocabPrevVarPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + shape=[4, 2], + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]], + partitioner=lambda shape, dtype: [2, 1]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.], [1.2, 1.5, 0.], + [2.3, 2., 0.]], fruit_output_layer.eval(sess)) + def testWarmStartVarWithVocabCurrentVarPartitioned(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -269,11 +321,43 @@ class WarmStartingUtilTest(test.TestCase): self.assertTrue( isinstance(fruit_weights, variables.PartitionedVariable)) fruit_weights_vars = fruit_weights._get_variable_list() - self.assertAllEqual([[2.], [1.5], [1.]], + self.assertAllClose([[2.], [1.5], [1.]], fruit_weights_vars[0].eval(sess)) - self.assertAllEqual([[0.5], [0.], [0.]], + self.assertAllClose([[0.5], [0.], [0.]], fruit_weights_vars[1].eval(sess)) + def testWarmStartVarWithColumnVocabCurrentVarPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + shape=[4, 3], + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]], + partitioner=lambda shape, dtype: [2, 1]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertTrue( + isinstance(fruit_output_layer, variables.PartitionedVariable)) + fruit_output_layer_vars = fruit_output_layer._get_variable_list() + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]], + fruit_output_layer_vars[0].eval(sess)) + self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]], + fruit_output_layer_vars[1].eval(sess)) + def testWarmStartVarWithVocabBothVarsPartitioned(self): prev_vocab_path = self._write_vocab(["apple", "banana", "guava", "orange"], "old_vocab") @@ -301,11 +385,45 @@ class WarmStartingUtilTest(test.TestCase): self.assertTrue( isinstance(fruit_weights, variables.PartitionedVariable)) fruit_weights_vars = fruit_weights._get_variable_list() - self.assertAllEqual([[2.], [1.5], [1.]], + self.assertAllClose([[2.], [1.5], [1.]], fruit_weights_vars[0].eval(sess)) - self.assertAllEqual([[0.5], [0.], [0.]], + self.assertAllClose([[0.5], [0.], [0.]], fruit_weights_vars[1].eval(sess)) + def testWarmStartVarWithColumnVocabBothVarsPartitioned(self): + prev_vocab_path = self._write_vocab(["apple", "orange"], "old_vocab") + self._create_prev_run_var( + "fruit_output_layer", + shape=[4, 2], + initializer=[[0.5, 0.3], [1., 0.8], [1.5, 1.2], [2., 2.3]], + partitioner=lambda shape, dtype: [2, 1]) + + # New vocab with elements in reverse order and one new element. + new_vocab_path = self._write_vocab(["orange", "apple", "banana"], + "new_vocab") + # New session and new graph. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + fruit_output_layer = variable_scope.get_variable( + "fruit_output_layer", + shape=[4, 3], + initializer=[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.], + [0., 0., 0.]], + partitioner=lambda shape, dtype: [2, 1]) + ws_util._warm_start_var_with_vocab(fruit_output_layer, new_vocab_path, + current_vocab_size=3, + prev_ckpt=self.get_temp_dir(), + prev_vocab_path=prev_vocab_path, + axis=1) + sess.run(variables.global_variables_initializer()) + self.assertTrue( + isinstance(fruit_output_layer, variables.PartitionedVariable)) + fruit_output_layer_vars = fruit_output_layer._get_variable_list() + self.assertAllClose([[0.3, 0.5, 0.], [0.8, 1.0, 0.]], + fruit_output_layer_vars[0].eval(sess)) + self.assertAllClose([[1.2, 1.5, 0.], [2.3, 2., 0.]], + fruit_output_layer_vars[1].eval(sess)) + def testWarmStart_ListOfVariables(self): # Save checkpoint from which to warm-start. _, prev_int_val = self._create_prev_run_var("v1", shape=[10, 1], diff --git a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt index 5301b94eb3..b6942cb7ed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.estimator.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt index 4ce7cb1111..39b946b82f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt index 5301b94eb3..b6942cb7ed 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.estimator.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt index 4ce7cb1111..39b946b82f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.-vocab-info.pbtxt @@ -3,6 +3,10 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" + member { + name: "axis" + mtype: "" + } member { name: "backup_initializer" mtype: "" -- GitLab From 25241c4270ca3c8679710fbe1803c836b6c983ea Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Wed, 5 Sep 2018 15:34:55 -0700 Subject: [PATCH 151/540] Update diagram in TOCO README. PiperOrigin-RevId: 211711493 --- tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg index 262e13a591..335debde57 100644 --- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg +++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg @@ -1 +1 @@ - \ No newline at end of file + \ No newline at end of file -- GitLab From 007443c69511aa001696a53150aa5a4334ffb8b9 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Wed, 5 Sep 2018 15:44:59 -0700 Subject: [PATCH 152/540] Temporarily disable distributed coordinator training when using TPUStrategy PiperOrigin-RevId: 211712907 --- tensorflow/python/estimator/run_config.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index b1ca207b62..3773810a04 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -521,7 +521,12 @@ class RunConfig(object): eval_distribute=eval_distribute, experimental_distribute=experimental_distribute) - if train_distribute or eval_distribute or experimental_distribute: + # TODO(frankchn,priyag): Eventually use distributed coordinator for TPUs. + if ((train_distribute and + train_distribute.__class__.__name__ != 'TPUStrategy') or + (eval_distribute and + eval_distribute.__class__.__name__ != 'TPUStrategy') or + experimental_distribute): logging.info('Initializing RunConfig with distribution strategies.') distribute_coordinator_training.init_run_config(self, tf_config) else: -- GitLab From b98d33daa08781d5b55a3c583f62e5753dc1da51 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 15:54:50 -0700 Subject: [PATCH 153/540] Mark tf.GraphKeys.VARIABLES as deprecated PiperOrigin-RevId: 211714574 --- tensorflow/python/framework/ops.py | 6 ++---- tensorflow/tools/compatibility/renames_v2.py | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 4cfd639bf9..9401309c19 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -55,6 +55,7 @@ from tensorflow.python.platform import app from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util import decorator_utils +from tensorflow.python.util import deprecation from tensorflow.python.util import function_utils from tensorflow.python.util import lock_util from tensorflow.python.util import tf_contextlib @@ -5807,11 +5808,8 @@ class GraphKeys(object): _STREAMING_MODEL_PORTS = "streaming_model_ports" @decorator_utils.classproperty + @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.") def VARIABLES(cls): # pylint: disable=no-self-argument - logging.log_first_n(logging.WARN, - "VARIABLES collection name is deprecated, please use " - "GLOBAL_VARIABLES instead; VARIABLES will be removed " - "after 2017-03-02.", 1) return cls.GLOBAL_VARIABLES diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index 216aa41b60..29c62763b0 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -67,6 +67,7 @@ renames = { 'tf.gather_nd': 'tf.manip.gather_nd', 'tf.greater': 'tf.math.greater', 'tf.greater_equal': 'tf.math.greater_equal', + 'tf.GraphKeys.VARIABLES': 'tf.GraphKeys.GLOBAL_VARIABLES', 'tf.ifft': 'tf.spectral.ifft', 'tf.igamma': 'tf.math.igamma', 'tf.igammac': 'tf.math.igammac', -- GitLab From b744cc00e1522d50463e2b681beae39cbb6f4d16 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 5 Sep 2018 16:00:28 -0700 Subject: [PATCH 154/540] Fix several build warnings in TFLite PiperOrigin-RevId: 211715608 --- tensorflow/contrib/lite/builtin_op_data.h | 10 +++++ tensorflow/contrib/lite/context.h | 40 ++++++++++--------- .../contrib/lite/kernels/eigen_support.h | 2 +- .../contrib/lite/nnapi_delegate_disabled.cc | 8 +++- 4 files changed, 40 insertions(+), 20 deletions(-) diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index e81f9e4f51..aecd71910c 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -25,6 +25,11 @@ extern "C" { // TODO(aselle): Consider using "if this then that" for testing. +// Useful placeholder to put in otherwise empty structs to avoid size warnings. +typedef struct { + char dummy_; +} EmptyStructPlaceholder; + // Possible padding types (for convolutions) typedef enum { kTfLitePaddingUnknown = 0, @@ -129,9 +134,11 @@ typedef struct { } TfLiteAddParams; typedef struct { + EmptyStructPlaceholder placeholder_; } TfLiteSpaceToBatchNDParams; typedef struct { + EmptyStructPlaceholder placeholder_; } TfLiteBatchToSpaceNDParams; typedef struct { @@ -178,9 +185,11 @@ typedef struct { } TfLiteResizeBilinearParams; typedef struct { + EmptyStructPlaceholder placeholder_; } TfLitePadParams; typedef struct { + EmptyStructPlaceholder placeholder_; } TfLitePadV2Params; typedef struct { @@ -220,6 +229,7 @@ typedef struct { } TfLiteGatherParams; typedef struct { + EmptyStructPlaceholder placeholder_; } TfLiteTransposeParams; typedef struct { diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index c7f4df3cdc..b23183b743 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -39,6 +39,12 @@ extern "C" { typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; +// Forward declarations for use with dependent types. +struct TfLiteContext; +struct TfLiteNode; +struct _TfLiteRegistration; +struct _TfLiteDelegate; + // The list of external context types known to TF Lite. This list exists solely // to avoid conflicts and to ensure ops can share the external contexts they // need. Access to the external contexts is controled by one of the @@ -60,10 +66,6 @@ typedef struct { TfLiteStatus (*Refresh)(struct TfLiteContext* context); } TfLiteExternalContext; -// Forward declare so GetNode can use this is in Context. -typedef struct _TfLiteRegistration TfLiteRegistration; -typedef struct _TfLiteDelegate TfLiteDelegate; - #define kOptionalTensor (-1) // Fixed size list of integers. Used for dimensions and inputs/outputs tensor @@ -240,7 +242,7 @@ typedef struct { // The delegate which knows how to handle `buffer_handle`. // WARNING: This is an experimental interface that is subject to change. - TfLiteDelegate* delegate; + struct _TfLiteDelegate* delegate; // An integer buffer handle that can be handled by `delegate`. // The value is valid only when delegate is not null. @@ -278,7 +280,7 @@ void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); // A structure representing an instance of a node. // This structure only exhibits the inputs, outputs and user defined data, not // other features like the type. -typedef struct { +typedef struct TfLiteNode { // Inputs to this node expressed as indices into the simulator's tensors. TfLiteIntArray* inputs; @@ -305,7 +307,7 @@ typedef struct { // The pointer to the delegate. This is non-null only when the node is // created by calling `interpreter.ModifyGraphWithDelegate`. // WARNING: This is an experimental interface that is subject to change. - TfLiteDelegate* delegate; + struct _TfLiteDelegate* delegate; } TfLiteNode; typedef struct TfLiteContext { @@ -351,15 +353,15 @@ typedef struct TfLiteContext { // Get a Tensor node by node_index. // WARNING: This is an experimental interface that is subject to change. - TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index, - TfLiteNode** node, - TfLiteRegistration** registration); + TfLiteStatus (*GetNodeAndRegistration)( + struct TfLiteContext*, int node_index, struct TfLiteNode** node, + struct _TfLiteRegistration** registration); // Replace ops with one or more stub delegate operations. This function // does not take ownership of `nodes_to_replace`. TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( - struct TfLiteContext*, TfLiteRegistration registration, - const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate); + struct TfLiteContext*, struct _TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace, struct _TfLiteDelegate* delegate); // Number of threads that are recommended to subsystems like gemmlowp and // eigen. @@ -447,19 +449,20 @@ typedef struct _TfLiteDelegate { // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() // to ask the TensorFlow lite runtime to create macro-nodes to represent // delegated subgraphs of the original graph. - TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate); + TfLiteStatus (*Prepare)(struct TfLiteContext* context, + struct _TfLiteDelegate* delegate); // Copy the data from delegate buffer handle to raw memory. // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context, - TfLiteDelegate* delegate, + TfLiteStatus (*CopyFromBufferHandle)(struct TfLiteContext* context, + struct _TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size); // Copy the data from raw memory to delegate buffer handle. // This can be null if the delegate doesn't use its own buffer. - TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context, - TfLiteDelegate* delegate, + TfLiteStatus (*CopyToBufferHandle)(struct TfLiteContext* context, + struct _TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle, void* data, size_t size); @@ -467,7 +470,8 @@ typedef struct _TfLiteDelegate { // this doesn't release the underlying resource (e.g. textures). The // resources are either owned by application layer or the delegate. // This can be null if the delegate doesn't use its own buffer. - void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate, + void (*FreeBufferHandle)(struct TfLiteContext* context, + struct _TfLiteDelegate* delegate, TfLiteBufferHandle* handle); } TfLiteDelegate; diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index ec77856b10..b235829642 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" namespace EigenForTFLite { -class ThreadPoolDevice; +struct ThreadPoolDevice; } namespace tflite { diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc index efde72b1a7..e3536d3db6 100644 --- a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc +++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc @@ -27,7 +27,13 @@ NNAPIAllocation::NNAPIAllocation(const char* filename, NNAPIAllocation::~NNAPIAllocation() {} -NNAPIDelegate::~NNAPIDelegate() {} +NNAPIDelegate::~NNAPIDelegate() { +#define UNUSED_MEMBER(x) (void)(x) + UNUSED_MEMBER(nn_model_); + UNUSED_MEMBER(nn_compiled_model_); + UNUSED_MEMBER(model_status_); +#undef UNUSED_MEMBER +} TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { return kTfLiteError; -- GitLab From 6ce8af21574ce71f94a8a06bde876d2f7bf690e5 Mon Sep 17 00:00:00 2001 From: Rachel Lim Date: Wed, 5 Sep 2018 16:12:12 -0700 Subject: [PATCH 155/540] [tf.data] Surface errors correctly in MapDefunOp by using different CancellationManagers for each run of the function. PiperOrigin-RevId: 211717580 --- .../python/kernel_tests/map_defun_op_test.py | 16 ++++++++++++++++ tensorflow/core/kernels/data/map_defun_op.cc | 6 ++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 73cde40305..091eb5ce37 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -130,6 +130,22 @@ class MapDefunTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): self.evaluate(result) + def testMapDefunCancelledCorrectly(self): + + @function.Defun(dtypes.int64) + def defun(x): + # x has leading dimension 5, this will raise an error + return array_ops.gather(x, 10) + + c = array_ops.tile( + array_ops.expand_dims( + constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0), + [100, 1]) + map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0] + with self.assertRaisesRegexp(errors.InvalidArgumentError, + r"indices = 10 is not in \[0, 5\)"): + self.evaluate(map_defun_op) + if __name__ == "__main__": test.main() diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 607d0ca028..cc4d7976f8 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -29,7 +29,6 @@ void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, bool always_collect_stats) { opts->step_id = ctx->step_id(); opts->rendezvous = ctx->rendezvous(); - opts->cancellation_manager = ctx->cancellation_manager(); if (always_collect_stats) { opts->stats_collector = ctx->stats_collector(); } @@ -117,10 +116,13 @@ class MapDefunOp : public AsyncOpKernel { for (size_t i = 0; i < static_cast(batch_size); ++i) { auto* call_frame = new MapFunctionCallFrame(*args, *arg_shapes, output, this, i); + CancellationManager* c_mgr = new CancellationManager; + opts_.cancellation_manager = c_mgr; ctx->function_library()->Run( opts_, func_handle_, call_frame, - [call_frame, refcounted](const Status& func_status) { + [call_frame, refcounted, c_mgr](const Status& func_status) { delete call_frame; + delete c_mgr; refcounted->UpdateStatus(func_status); refcounted->Unref(); }); -- GitLab From df7930083b73b91959420dc2f92463befbac5af4 Mon Sep 17 00:00:00 2001 From: Youlong Cheng Date: Wed, 5 Sep 2018 16:16:46 -0700 Subject: [PATCH 156/540] Implements TPU alltoall op. RELNOTES: n/a PiperOrigin-RevId: 211718248 --- .../contrib/tpu/ops/cross_replica_ops.cc | 89 ++++++++++++++++++- tensorflow/contrib/tpu/python/ops/tpu_ops.py | 64 +++++++++++-- 2 files changed, 142 insertions(+), 11 deletions(-) diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index 9ee5ecb123..ea8e0e00ed 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -18,6 +18,89 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("AllToAll") + .Input("input: T") + .Input("group_assignment: int32") + .Output("output: T") + .Attr("T: {bfloat16, float}") + .Attr("concat_dimension: int") + .Attr("split_dimension: int") + .Attr("split_count: int") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input = c->input(0); + int64 rank; + if (c->RankKnown(input)) { + rank = c->Rank(input); + } else { + return errors::InvalidArgument("input's rank is unknown."); + } + int concat_dimension; + int split_dimension; + + TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension)); + + if (concat_dimension < 0 || concat_dimension >= rank) { + return errors::InvalidArgument("concat_dimension ", concat_dimension, + " is out of range of input rank ", rank); + } + + TF_RETURN_IF_ERROR(c->GetAttr("split_dimension", &split_dimension)); + if (split_dimension < 0 || split_dimension >= rank) { + return errors::InvalidArgument("split_dimension ", split_dimension, + " is out of range of input rank ", rank); + } + + std::vector dims; + dims.resize(rank); + + for (int32 i = 0; i < rank; ++i) { + int64 in_idx = i; + if (i == concat_dimension) { + in_idx = split_dimension; + } else if (i == split_dimension) { + in_idx = concat_dimension; + } + + dims[i] = c->Dim(input, in_idx); + } + + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + }) + .Doc(R"doc( +An Op to exchange data across TPU replicas. On each replica, the input is +split into `split_count` blocks along `split_dimension` and send to the other +replicas given group_assignment. After receiving `split_count` - 1 blocks from +other replicas, we concatenate the blocks along `concat_dimension` as the +output. + +For example, suppose there are 2 TPU replicas: +replica 0 receives input: `[[A, B]]` +replica 1 receives input: `[[C, D]]` + +group_assignment=`[[0, 1]]` +concat_dimension=0 +split_dimension=1 +split_count=2 + +replica 0's output: `[[A], [C]]` +replica 1's output: `[[B], [D]]` + +input: The local input to the sum. +group_assignment: An int32 tensor with shape + [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the + replica ids in the ith subgroup. +concat_dimension: The dimension number to concatenate. +split_dimension: The dimension number to split. +split_count: The number of splits, this number must equal to the sub-group + size(group_assignment.get_shape()[1]) +output: The exchanged result. +T: The type of elements to be exchanged. +)doc"); REGISTER_OP("CrossReplicaSum") .Input("input: T") @@ -26,10 +109,8 @@ REGISTER_OP("CrossReplicaSum") .Attr("T: {bfloat16, float}") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( -An Op to sum inputs across replicated TPU instances. Each -instance supplies its own input. If group_assignment is empty, the output of -each is the sum of all the inputs, otherwise the output of each is the sum of -the inputs belonging to the same group. +An Op to sum inputs across replicated TPU instances. Each instance supplies its +own input. For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`. Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 3ed571aff9..d92a0652bb 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -38,6 +38,62 @@ if platform.system() != "Windows": _tpu_ops = loader.load_op_library( resource_loader.get_path_to_datafile("_tpu_ops.so")) + def _create_default_group_assignment(): + num_shards = tpu_function.get_tpu_context().number_of_shards + if num_shards is None: + logging.warning( + "cross_replica_sum should be used within a tpu_shard_context, but " + "got unset number_of_shards. Assuming 1.") + num_shards = 1 + group_assignment = [list(range(num_shards))] + return group_assignment + + def all_to_all(x, + concat_dimension, + split_dimension, + split_count, + group_assignment=None, + name=None): + """Exchange data across TPU replicas. + + Args: + x: The local tensor. + concat_dimension: The dimension number to concatenate. + split_dimension: The dimension number to split. + split_count: The number of splits, this number must equal to the sub-group + size(group_assignment.get_shape()[1]) + group_assignment: Optional 2d int32 lists with shape [num_groups, + num_replicas_per_group]. `group_assignment[i]` represents the replica + ids in the ith subgroup. + name: Optional op name. + + Returns: + A `Tensor` which is concatenated by data from different replicas. + """ + if group_assignment is None: + group_assignment = _create_default_group_assignment() + return gen_tpu_ops.all_to_all( + x, + group_assignment, + concat_dimension=concat_dimension, + split_dimension=split_dimension, + split_count=split_count, + name=name) + + @ops.RegisterGradient("AllToAll") + def _all_to_all_grad(op, grad): + # The gradient of a all-to-all is also a all-to-all but the + # split_dimension and concat_dimension is swapped. + # The graident with respect to group_assignment is None. + return [ + gen_tpu_ops.all_to_all( + grad, + op.inputs[1], + concat_dimension=op.get_attr("split_dimension"), + split_dimension=op.get_attr("concat_dimension"), + split_count=op.get_attr("split_count")), None + ] + def cross_replica_sum(x, group_assignment=None, name=None): """Sum the input tensor accorss replicas according to group_assignment. @@ -52,13 +108,7 @@ if platform.system() != "Windows": A `Tensor` which is summed across replicas. """ if group_assignment is None: - num_shards = tpu_function.get_tpu_context().number_of_shards - if num_shards is None: - logging.warning( - "cross_replica_sum should be used within a tpu_shard_context, but " - "got unset number_of_shards. Assuming 1.") - num_shards = 1 - group_assignment = [list(range(num_shards))] + group_assignment = _create_default_group_assignment() return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) -- GitLab From 0a3036e9865672229619d1e673a8bf64a2c723d1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 16:22:12 -0700 Subject: [PATCH 157/540] Re-added proto field for dynamic learning rate support (not usable yet). PiperOrigin-RevId: 211719009 --- .../contrib/tpu/proto/optimization_parameters.proto | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index cbf6809257..fc1320501b 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -9,8 +9,8 @@ message ClippingLimits { google.protobuf.FloatValue upper = 2; // +inf if not set } -// Get the learning rate from a source that can change -// dynamically. +// Get the learning rate from the parameters of the SendTPUEmbeddingGradients +// op. message DynamicLearningRate { } @@ -18,10 +18,8 @@ message DynamicLearningRate { message LearningRate { oneof learning_rate { float constant = 1; - // DynamicLearningRate dynamic = 2; -- disabled while code is being - // rewritten. + DynamicLearningRate dynamic = 2; } - reserved 2; } message AdagradParameters { -- GitLab From bded7fb63e932c7a7139a32d0e958479d90dbc1d Mon Sep 17 00:00:00 2001 From: Olivia Nordquist Date: Wed, 5 Sep 2018 16:24:10 -0700 Subject: [PATCH 158/540] disable msan in failing test PiperOrigin-RevId: 211719342 --- tensorflow/python/estimator/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index f6ef6d8dcb..cf8e18b216 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -687,6 +687,7 @@ py_test( "manual", # b/112769036, b/113907597 "no_oss", # b/112769036, b/113907597 "no_windows", + "nomsan", "notsan", # b/67510291 ], deps = [ -- GitLab From 2c8bc1587e9480a44c10146d0e9472c1d6f9c7d7 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Wed, 5 Sep 2018 16:24:29 -0700 Subject: [PATCH 159/540] Fix lite_test.py. PiperOrigin-RevId: 211719399 --- tensorflow/contrib/lite/python/BUILD | 2 +- tensorflow/contrib/lite/python/lite.py | 15 +++++++++++---- tensorflow/contrib/lite/python/lite_test.py | 19 +++++++++++++++---- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 6e30251eff..57e1290e07 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -70,7 +70,7 @@ py_library( py_test( name = "lite_test", srcs = ["lite_test.py"], - data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pbtxt"], + data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"], srcs_version = "PY2AND3", tags = [ "no_oss", diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 2de97fec86..44dfb97b84 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -58,6 +58,7 @@ from tensorflow.python.framework import graph_util as _tf_graph_util from tensorflow.python.framework import ops as _ops from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError from tensorflow.python.framework.importer import import_graph_def as _import_graph_def +from tensorflow.python.lib.io import file_io as _file_io from tensorflow.python.saved_model import signature_constants as _signature_constants from tensorflow.python.saved_model import tag_constants as _tag_constants @@ -225,8 +226,10 @@ class TocoConverter(object): TocoConverter class. Raises: - ValueError: + IOError: + File not found. Unable to parse input file. + ValueError: The graph is not frozen. input_arrays or output_arrays contains an invalid tensor name. input_shapes is not correctly defined when required @@ -234,10 +237,13 @@ class TocoConverter(object): with _ops.Graph().as_default(): with _session.Session() as sess: # Read GraphDef from file. - graph_def = _graph_pb2.GraphDef() - with open(graph_def_file, "rb") as f: + if not _file_io.file_exists(graph_def_file): + raise IOError("File '{0}' does not exist.".format(graph_def_file)) + with _file_io.FileIO(graph_def_file, "rb") as f: file_content = f.read() + try: + graph_def = _graph_pb2.GraphDef() graph_def.ParseFromString(file_content) except (_text_format.ParseError, DecodeError): try: @@ -248,9 +254,10 @@ class TocoConverter(object): file_content = file_content.decode("utf-8") else: file_content = file_content.encode("utf-8") + graph_def = _graph_pb2.GraphDef() _text_format.Merge(file_content, graph_def) except (_text_format.ParseError, DecodeError): - raise ValueError( + raise IOError( "Unable to parse input file '{}'.".format(graph_def_file)) # Handles models with custom TFLite ops that cannot be resolved in diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 1c94ba605a..3f8ea433ff 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -521,14 +521,21 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) self.assertEqual((0., 0.), output_details[0]['quantization']) - def testInvalidFile(self): + def testInvalidFileNotFound(self): + with self.assertRaises(IOError) as error: + lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'], + ['add']) + self.assertEqual('File \'invalid_file\' does not exist.', + str(error.exception)) + + def testInvalidFileBadData(self): graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') with gfile.Open(graph_def_file, 'wb') as temp_file: temp_file.write('bad data') temp_file.flush() # Attempts to convert the invalid model. - with self.assertRaises(ValueError) as error: + with self.assertRaises(IOError) as error: lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], ['add']) self.assertEqual( @@ -539,7 +546,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): def _initObjectDetectionArgs(self): # Initializes the arguments required for the object detection model. self._graph_def_file = resource_loader.get_path_to_datafile( - 'testdata/tflite_graph.pbtxt') + 'testdata/tflite_graph.pb') self._input_arrays = ['normalized_input_image_tensor'] self._output_arrays = [ 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', @@ -586,7 +593,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): output_details[3]['name']) self.assertTrue(([1] == output_details[3]['shape']).all()) - def testTFLiteGraphDefInvalid(self): + def testTFLiteGraphDefMissingShape(self): # Tests invalid cases for the model that cannot be loaded in TensorFlow. self._initObjectDetectionArgs() @@ -597,6 +604,10 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase): self.assertEqual('input_shapes must be defined for this model.', str(error.exception)) + def testTFLiteGraphDefInvalidShape(self): + # Tests invalid cases for the model that cannot be loaded in TensorFlow. + self._initObjectDetectionArgs() + # `input_shapes` does not contain the names in `input_arrays`. with self.assertRaises(ValueError) as error: lite.TocoConverter.from_frozen_graph( -- GitLab From 352d2a0a2a099ae830855c94a30f9ea657556aef Mon Sep 17 00:00:00 2001 From: Niranjan Hasabnis Date: Wed, 5 Sep 2018 16:35:38 -0700 Subject: [PATCH 160/540] Addressing review comments --- tensorflow/core/common_runtime/mkl_cpu_allocator.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 553f07020e..200ca57a9a 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -133,7 +133,8 @@ class MklSmallSizeAllocator : public VisitableAllocator { private: // Increment statistics for the allocator handling small allocations. - inline void IncrementStats(size_t alloc_size) GUARDED_BY(mutex_) { + inline void + IncrementStats(size_t alloc_size) EXCLUSIVE_LOCKS_REQUIRED(mutex_) { ++stats_.num_allocs; stats_.bytes_in_use += alloc_size; stats_.max_bytes_in_use = std::max(stats_.max_bytes_in_use, @@ -143,7 +144,8 @@ class MklSmallSizeAllocator : public VisitableAllocator { } // Decrement statistics for the allocator handling small allocations. - inline void DecrementStats(size_t dealloc_size) GUARDED_BY(mutex_) { + inline void + DecrementStats(size_t dealloc_size) EXCLUSIVE_LOCKS_REQUIRED(mutex_) { stats_.bytes_in_use -= dealloc_size; } -- GitLab From 7dfc0756439aede05ec471193780a4de9f61874e Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 5 Sep 2018 16:38:33 -0700 Subject: [PATCH 161/540] Propagate eager output tensor types in TFLite PiperOrigin-RevId: 211721354 --- .../lite/delegates/eager/delegate_test.cc | 20 +++++++++ .../contrib/lite/delegates/eager/kernel.cc | 2 +- .../contrib/lite/delegates/eager/test_util.cc | 43 ++++++++++--------- .../contrib/lite/delegates/eager/test_util.h | 28 ++++++++++-- .../contrib/lite/delegates/eager/util.cc | 36 +++++++++++++++- .../contrib/lite/delegates/eager/util.h | 13 ++++-- .../contrib/lite/delegates/eager/util_test.cc | 38 +++++++++++++--- 7 files changed, 145 insertions(+), 35 deletions(-) diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc index eb47f46c0b..984f8bbc98 100644 --- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc @@ -72,6 +72,26 @@ TEST_F(DelegateTest, FullGraph) { ASSERT_THAT(GetShape(8), ElementsAre(2, 1)); ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f)); + ASSERT_EQ(GetType(8), kTfLiteFloat32); +} + +TEST_F(DelegateTest, NonFloatTypeInference) { + AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2}); + + AddTfOp(testing::kAdd, {0, 1}, {2}); + + ConfigureDelegate(); + + SetShape(0, {2, 2}); + SetTypedValues(0, {1, 2, 3, 4}); + SetShape(1, {2, 2}); + SetTypedValues(1, {4, 3, 2, 1}); + + ASSERT_TRUE(Invoke()); + + ASSERT_THAT(GetShape(2), ElementsAre(2, 2)); + ASSERT_THAT(GetTypedValues(2), ElementsAre(5, 5, 5, 5)); + ASSERT_EQ(GetType(2), kTfLiteInt32); } TEST_F(DelegateTest, MixedGraph) { diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc index f8467c7cb2..0ee4db1ffb 100644 --- a/tensorflow/contrib/lite/delegates/eager/kernel.cc +++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc @@ -278,7 +278,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* tensor = &context->tensors[tensor_index]; TF_LITE_ENSURE_OK( context, - CopyShape(context, buffer_map->GetTensor(tensor_index), tensor)); + CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor)); tensor->buffer_handle = tensor_index; tensor->data_is_stale = true; } diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc index b8c9e2652a..8584999ace 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.cc +++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc @@ -25,19 +25,6 @@ namespace testing { bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; } -void EagerModelTest::SetValues(int tensor_index, - const std::vector& values) { - float* v = interpreter_->typed_tensor(tensor_index); - for (float f : values) { - *v++ = f; - } -} - -std::vector EagerModelTest::GetValues(int tensor_index) { - TfLiteTensor* o = interpreter_->tensor(tensor_index); - return std::vector(o->data.f, o->data.f + o->bytes / sizeof(float)); -} - void EagerModelTest::SetShape(int tensor_index, const std::vector& values) { ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk); @@ -54,13 +41,21 @@ std::vector EagerModelTest::GetShape(int tensor_index) { return result; } +TfLiteType EagerModelTest::GetType(int tensor_index) { + return interpreter_->tensor(tensor_index)->type; +} + void EagerModelTest::AddTensors(int num_tensors, const std::vector& inputs, const std::vector& outputs, - const TfLiteType& type, - const std::vector& dims) { + TfLiteType type, const std::vector& dims) { interpreter_->AddTensors(num_tensors); for (int i = 0; i < num_tensors; ++i) { TfLiteQuantizationParams quant; + // Suppress explicit output type specification to ensure type inference + // works properly. + if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) { + type = kTfLiteFloat32; + } CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type, /*name=*/"", /*dims=*/dims, quant), @@ -101,18 +96,26 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector& inputs, return " attr{ key: '" + key + "' value {" + value + "}}"; }; + // Crude type attribution, will need fleshing out as more tests are added. + // TODO(b/113613439): Use nodedef string utilities to properly handle + // all types. + string type_attribute = attr("T", "type: DT_FLOAT"); + if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) { + type_attribute = attr("T", "type: DT_INT32"); + } + if (op == kUnpack) { - string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") + - attr("axis", "i: 0"); + string attributes = + type_attribute + attr("num", "i: 2") + attr("axis", "i: 0"); AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs); } else if (op == kIdentity) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs); } else if (op == kAdd) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerAdd", "Add", attributes, inputs, outputs); } else if (op == kMul) { - string attributes = attr("T", "type: DT_FLOAT"); + string attributes = type_attribute; AddTfOp("EagerMul", "Mul", attributes, inputs, outputs); } else if (op == kNonExistent) { AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs); diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h index 0eab9e1135..816db41931 100644 --- a/tensorflow/contrib/lite/delegates/eager/test_util.h +++ b/tensorflow/contrib/lite/delegates/eager/test_util.h @@ -44,11 +44,30 @@ class EagerModelTest : public ::testing::Test { bool Invoke(); + // Sets the (typed) tensor's values at the given index. + template + void SetTypedValues(int tensor_index, const std::vector& values) { + memcpy(interpreter_->typed_tensor(tensor_index), values.data(), + values.size() * sizeof(T)); + } + + // Returns the (typed) tensor's values at the given index. + template + std::vector GetTypedValues(int tensor_index) { + const TfLiteTensor* t = interpreter_->tensor(tensor_index); + const T* tdata = interpreter_->typed_tensor(tensor_index); + return std::vector(tdata, tdata + t->bytes / sizeof(T)); + } + // Sets the tensor's values at the given index. - void SetValues(int tensor_index, const std::vector& values); + void SetValues(int tensor_index, const std::vector& values) { + SetTypedValues(tensor_index, values); + } // Returns the tensor's values at the given index. - std::vector GetValues(int tensor_index); + std::vector GetValues(int tensor_index) { + return GetTypedValues(tensor_index); + } // Sets the tensor's shape at the given index. void SetShape(int tensor_index, const std::vector& values); @@ -56,13 +75,16 @@ class EagerModelTest : public ::testing::Test { // Returns the tensor's shape at the given index. std::vector GetShape(int tensor_index); + // Returns the tensor's type at the given index. + TfLiteType GetType(int tensor_index); + const TestErrorReporter& error_reporter() const { return error_reporter_; } // Adds `num_tensor` tensors to the model. `inputs` contains the indices of // the input tensors and `outputs` contains the indices of the output // tensors. All tensors are set to have `type` and `dims`. void AddTensors(int num_tensors, const std::vector& inputs, - const std::vector& outputs, const TfLiteType& type, + const std::vector& outputs, TfLiteType type, const std::vector& dims); // Adds a TFLite Mul op. `inputs` contains the indices of the input tensors diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc index 4426c653e6..051246bf86 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.cc +++ b/tensorflow/contrib/lite/delegates/eager/util.cc @@ -26,8 +26,17 @@ TfLiteStatus ConvertStatus(TfLiteContext* context, return kTfLiteOk; } -TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, - TfLiteTensor* tensor) { +TfLiteStatus CopyShapeAndType(TfLiteContext* context, + const tensorflow::Tensor& src, + TfLiteTensor* tensor) { + tensor->type = GetTensorFlowLiteType(static_cast(src.dtype())); + if (tensor->type == kTfLiteNoType) { + context->ReportError(context, + "TF Lite does not support TensorFlow data type: %s", + DataTypeString(src.dtype()).c_str()); + return kTfLiteError; + } + int num_dims = src.dims(); TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims); for (int j = 0; j < num_dims; ++j) { @@ -68,5 +77,28 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) { } } +TfLiteType GetTensorFlowLiteType(TF_DataType type) { + switch (type) { + case TF_FLOAT: + return kTfLiteFloat32; + case TF_INT16: + return kTfLiteInt16; + case TF_INT32: + return kTfLiteInt32; + case TF_UINT8: + return kTfLiteUInt8; + case TF_INT64: + return kTfLiteInt64; + case TF_COMPLEX64: + return kTfLiteComplex64; + case TF_STRING: + return kTfLiteString; + case TF_BOOL: + return kTfLiteBool; + default: + return kTfLiteNoType; + } +} + } // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h index a9407be071..ff500d18f3 100644 --- a/tensorflow/contrib/lite/delegates/eager/util.h +++ b/tensorflow/contrib/lite/delegates/eager/util.h @@ -28,14 +28,19 @@ namespace eager { TfLiteStatus ConvertStatus(TfLiteContext* context, const tensorflow::Status& status); -// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an -// error and returns kTfLiteError if the shape can't be converted. -TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src, - TfLiteTensor* tensor); +// Copies the given shape and type of the TensorFlow 'src' tensor into a TF Lite +// 'tensor'. Logs an error and returns kTfLiteError if the shape or type can't +// be converted. +TfLiteStatus CopyShapeAndType(TfLiteContext* context, + const tensorflow::Tensor& src, + TfLiteTensor* tensor); // Returns the TF C API Data type that corresponds to the given TfLiteType. TF_DataType GetTensorFlowDataType(TfLiteType type); +// Returns the TfLiteType that corresponds to the given TF C API Data type. +TfLiteType GetTensorFlowLiteType(TF_DataType); + } // namespace eager } // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc index 53378a1eaf..aebc91149c 100644 --- a/tensorflow/contrib/lite/delegates/eager/util_test.cc +++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc @@ -26,6 +26,7 @@ namespace eager { namespace { using tensorflow::DT_FLOAT; +using tensorflow::DT_INT32; using tensorflow::Tensor; using ::testing::ElementsAre; @@ -71,27 +72,41 @@ TEST(UtilTest, ConvertStatus) { EXPECT_TRUE(context.error.empty()); } -TEST(UtilTest, CopyShape) { +TEST(UtilTest, CopyShapeAndType) { TestContext context; context.ReportError = ReportError; context.ResizeTensor = ResizeTensor; TfLiteTensor dst; - EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk); + EXPECT_EQ(CopyShapeAndType(&context, Tensor(), &dst), kTfLiteOk); EXPECT_THAT(context.new_size, ElementsAre(0)); + EXPECT_EQ(dst.type, kTfLiteFloat32); - EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk); + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1, 2}), &dst), + kTfLiteOk); EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteFloat32); - EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst), + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_INT32, {1, 2}), &dst), + kTfLiteOk); + EXPECT_THAT(context.new_size, ElementsAre(1, 2)); + EXPECT_EQ(dst.type, kTfLiteInt32); + + EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst), kTfLiteError); EXPECT_EQ(context.error, "Dimension value in TensorFlow shape is larger than supported by " "TF Lite"); + + EXPECT_EQ( + CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst), + kTfLiteError); + EXPECT_EQ(context.error, + "TF Lite does not support TensorFlow data type: half"); } -TEST(UtilTest, TypeConversions) { +TEST(UtilTest, TypeConversionsFromTFLite) { EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType)); EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32)); EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16)); @@ -103,6 +118,19 @@ TEST(UtilTest, TypeConversions) { EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool)); } +TEST(UtilTest, TypeConversionsFromTensorFlow) { + EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT)); + EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16)); + EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32)); + EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8)); + EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64)); + EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64)); + EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING)); + EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL)); + EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE)); + EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT)); +} + } // namespace } // namespace eager } // namespace tflite -- GitLab From 0eaf0f8ac6791ef2b841fa08aff41d85be189e9f Mon Sep 17 00:00:00 2001 From: Raghuraman Krishnamoorthi Date: Wed, 5 Sep 2018 16:39:38 -0700 Subject: [PATCH 162/540] Upload floating point mobilenet-v2 and resnet-v2-101 models. Also upload fully quantized mobilenet-v2 and inception-v3 models. PiperOrigin-RevId: 211721504 --- tensorflow/contrib/lite/g3doc/models.md | 68 +++++++++++++------------ 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md index 0f9d016e6d..88f6cda420 100644 --- a/tensorflow/contrib/lite/g3doc/models.md +++ b/tensorflow/contrib/lite/g3doc/models.md @@ -3,33 +3,34 @@ ## Image classification (Float Models) -Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance -------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------: -DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms -SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms -NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 74.2% | 91.7% | 261 ms | 389 ms -NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.8% | 96.2% | 6697 ms | 7940 ms -ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms -ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms -Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 78.2% | 94.0% | 1433 ms | 1522 ms -Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.4% | 95.2% | 2986 ms | 3139 ms -Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.8% | 94.1% | 2731 ms | 2926 ms -Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.6% | 66.6% | 6.2 ms | 13.0 ms -Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.7% | 70.6% | 8.6 ms | 19.5 ms -Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.5% | 72.4% | 12.1 ms | 27.8 ms -Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 50.0% | 74.4% | 16.2 ms | 37.3 ms -Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.5% | 79.5% | 18.1 ms | 29.9 ms -Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.3% | 82.1% | 26.8 ms | 45.9 ms -Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 62.0% | 83.7% | 35.6 ms | 65.3 ms -Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.5% | 85.0% | 47.6 ms | 164.2 ms -Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.3% | 84.1% | 34.6 ms | 48.7 ms -Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.5% | 86.1% | 51.3 ms | 75.2 ms -Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.4% | 87.4% | 71.7 ms | 107.0 ms -Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.6% | 88.3% | 95.7 ms | 143.4 ms -Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.5% | 85.9% | 57.4 ms | 76.8 ms -Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.3% | 87.8% | 86.0 ms | 117.7 ms -Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.2% | 89.3% | 118.6 ms | 167.3 ms -Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.3% | 90.1% | 160.1 ms | 224.3 ms +Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance +--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------: +DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms +SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms +NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 74.2% | 91.7% | 261 ms | 389 ms +NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.8% | 96.2% | 6697 ms | 7940 ms +ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms +ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms +Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 78.2% | 94.0% | 1433 ms | 1522 ms +Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.4% | 95.2% | 2986 ms | 3139 ms +Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.8% | 94.1% | 2731 ms | 2926 ms +Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.6% | 66.6% | 6.2 ms | 13.0 ms +Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.7% | 70.6% | 8.6 ms | 19.5 ms +Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.5% | 72.4% | 12.1 ms | 27.8 ms +Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 50.0% | 74.4% | 16.2 ms | 37.3 ms +Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.5% | 79.5% | 18.1 ms | 29.9 ms +Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.3% | 82.1% | 26.8 ms | 45.9 ms +Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 62.0% | 83.7% | 35.6 ms | 65.3 ms +Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.5% | 85.0% | 47.6 ms | 164.2 ms +Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.3% | 84.1% | 34.6 ms | 48.7 ms +Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.5% | 86.1% | 51.3 ms | 75.2 ms +Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.4% | 87.4% | 71.7 ms | 107.0 ms +Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.6% | 88.3% | 95.7 ms | 143.4 ms +Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.5% | 85.9% | 57.4 ms | 76.8 ms +Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.3% | 87.8% | 86.0 ms | 117.7 ms +Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.2% | 89.3% | 118.6 ms | 167.3 ms +Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.3% | 90.1% | 160.1 ms | 224.3 ms +Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.9% | 90.1% | 117 ms | ^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph. @@ -41,8 +42,8 @@ after excluding blacklisted images. ## Image classification (Quantized Models) -Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance ------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: +Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance +--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------: Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.8% | 64.8% | 3.7 ms Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.0% | 68.4% | 5.5 ms Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms @@ -59,9 +60,12 @@ Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tf Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.9% | 37.4 ms Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.4% | 88.3% | 51.9 ms Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.2% | 89.1% | 70.2 ms +Mobilenet_v2_1.0_224_quant | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz) | 3.4 Mb | 71.1% | 90.1% | 80.3 ms +Inception_v3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.6% | 637 ms ## Other models -Model | TF Lite FlatBuffer ------------------------ | :----------------: -Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip) +Lite FlatBuffer ----------------------- | :----------------: Smart Reply 1.0 +Android | +[reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), +[tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip) -- GitLab From ad6248bf67eb91efe43da714ded953d698580732 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 16:43:42 -0700 Subject: [PATCH 163/540] Convert more kernel signatures to use runtime shapes. PiperOrigin-RevId: 211722113 --- .../internal/reference/reference_ops.h | 102 ++++++++++++++---- .../contrib/lite/kernels/internal/types.h | 5 + 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 00f9616cc2..a027a47726 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3398,10 +3398,12 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, } } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, - const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Dequantize(const tflite::DequantizationParams& op_params, + const RuntimeShape& input_shape, const uint8* input_data, + const RuntimeShape& output_shape, float* output_data) { + int32 zero_point = op_params.zero_point; + double scale = op_params.scale; + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { int32 val = input_data[i]; @@ -3410,9 +3412,25 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, } } -inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, - float rmin, float rmax, int num_bits, float* output_data, - const Dims<4>& output_dims) { +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + tflite::DequantizationParams op_params; + op_params.zero_point = zero_point; + op_params.scale = scale; + + Dequantize(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +inline void FakeQuant(const tflite::FakeQuantParams& op_params, + const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + float rmin = op_params.minmax.min; + float rmax = op_params.minmax.max; + int num_bits = op_params.num_bits; // 0 should always be a representable value. Let's assume that the initial // min,max range contains 0. TFLITE_DCHECK_LE(rmin, 0.0f); @@ -3425,11 +3443,25 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, float nudged_min, nudged_max, nudged_scale; NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, &nudged_max, &nudged_scale); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data, output_data, flat_size); } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, int num_bits, float* output_data, + const Dims<4>& output_dims) { + tflite::FakeQuantParams op_params; + op_params.num_bits = num_bits; + op_params.minmax.min = rmin; + op_params.minmax.max = rmax; + + FakeQuant(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data, const RuntimeShape& output_shape, DstT* output_data) { @@ -4050,22 +4082,32 @@ inline bool Mean(const T* input_data, const int* input_dims, } template -inline void Mean(const T* input_data, const Dims<4>& input_dims, - const std::vector& reduction_indices, T* output_data, - const Dims<4>& output_dims) { - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); +inline void Mean(const tflite::MeanParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + gemmlowp::ScopedProfilingLabel label("Mean"); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_batch = output_shape.Dims(0); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int output_depth = output_shape.Dims(3); + + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); // The current implementation only supports simultaneous reduction over // width and height. - TFLITE_DCHECK_EQ(reduction_indices.size(), 2); - TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || - (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(op_params.axis_count, 2); + TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); TFLITE_DCHECK_EQ(output_height, 1); TFLITE_DCHECK_EQ(output_width, 1); @@ -4074,15 +4116,31 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, float value = 0; for (int in_h = 0; in_h < input_height; ++in_h) { for (int in_w = 0; in_w < input_width; ++in_w) { - value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + value += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; } } - output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = value / (input_width * input_height); } } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + tflite::MeanParams op_params; + op_params.axis_count = reduction_indices.size(); + for (int i = 0; i < op_params.axis_count; ++i) { + op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i]; + } + + Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + // Computes the mean of elements across dimensions given in axis. // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis for quantized values. diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 9f6e74a267..c4c7cf3842 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -769,6 +769,11 @@ struct DepthwiseParams { int32 output_activation_max; }; +struct DequantizationParams { + double scale; + int32 zero_point; +}; + struct FakeQuantParams { MinMax minmax; int32 num_bits; -- GitLab From e7b37766f53d5d9d976f2ba3046d3df3333c8ebb Mon Sep 17 00:00:00 2001 From: Brennan Saeta Date: Wed, 5 Sep 2018 17:03:46 -0700 Subject: [PATCH 164/540] [Keras / Cloud TPU]: Correct indexing for software pipelining. PiperOrigin-RevId: 211724843 --- tensorflow/contrib/tpu/python/tpu/keras_support.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index dd7f8b678f..08e0465b71 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -1657,7 +1657,7 @@ class KerasTPUModel(models.Model): 'make sure your paths are correct and you have ' 'permissions to read the files. Skipping validation') - for step_index in range(steps_per_epoch - 1): + for step_index in range(steps_per_epoch): batch_logs = {'batch': step_index, 'size': 1} callbacks.on_batch_begin(step_index, batch_logs) try: -- GitLab From 017599d0a1fa7a7227a43649db67e96311033a4e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 17:13:24 -0700 Subject: [PATCH 165/540] This CL changes the graph-mode API of the learning_rate_decay functions in TF 2.0 to return a no-arg callable to output a learning rate, instead of directly outputting a learning rate tensor. This brings the graph mode API in line with the eager execution API, where this change was made to allow changing the learning rate value across different invocations of optimizer functions. PiperOrigin-RevId: 211726295 --- tensorflow/python/BUILD | 1 + .../python/training/learning_rate_decay.py | 432 +++------ .../python/training/learning_rate_decay_v2.py | 898 ++++++++++++++++++ .../training/learning_rate_decay_v2_test.py | 497 ++++++++++ .../tools/compatibility/tf_upgrade_v2.py | 24 + .../tools/compatibility/tf_upgrade_v2_test.py | 13 + 6 files changed, 1547 insertions(+), 318 deletions(-) create mode 100644 tensorflow/python/training/learning_rate_decay_v2.py create mode 100644 tensorflow/python/training/learning_rate_decay_v2_test.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e6169e9e80..ba9c6a2320 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4393,6 +4393,7 @@ cuda_py_tests( "training/ftrl_test.py", "training/gradient_descent_test.py", "training/learning_rate_decay_test.py", + "training/learning_rate_decay_v2_test.py", "training/momentum_test.py", "training/optimizer_test.py", "training/proximal_adagrad_test.py", diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py index fd195a7965..29b5465321 100644 --- a/tensorflow/python/training/learning_rate_decay.py +++ b/tensorflow/python/training/learning_rate_decay.py @@ -17,19 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.training import learning_rate_decay_v2 from tensorflow.python.util.tf_export import tf_export -@tf_export("train.exponential_decay") +@tf_export(v1=["train.exponential_decay"]) def exponential_decay(learning_rate, global_step, decay_steps, @@ -95,32 +88,19 @@ def exponential_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("global_step is required for exponential_decay.") - with ops.name_scope( - name, "ExponentialDecay", - [learning_rate, global_step, decay_steps, decay_rate]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - decay_rate = math_ops.cast(decay_rate, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - p = global_step_recomp / decay_steps - if staircase: - p = math_ops.floor(p) - return math_ops.multiply( - learning_rate, math_ops.pow(decay_rate, p), name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr - - -@tf_export("train.piecewise_constant") + decayed_lr = learning_rate_decay_v2.exponential_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=staircase, + name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr + + +@tf_export(v1=["train.piecewise_constant"]) def piecewise_constant(x, boundaries, values, name=None): """Piecewise constant from boundaries and interval values. @@ -163,58 +143,15 @@ def piecewise_constant(x, boundaries, values, name=None): the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if len(boundaries) != len(values) - 1: - raise ValueError( - "The length of boundaries should be 1 less than the length of values") - with ops.name_scope(name, "PiecewiseConstant", - [x, boundaries, values, name]) as name: - boundaries = ops.convert_n_to_tensor(boundaries) - values = ops.convert_n_to_tensor(values) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - x_recomp = ops.convert_to_tensor(x) - # Avoid explicit conversion to x's dtype. This could result in faulty - # comparisons, for example if floats are converted to integers. - for i, b in enumerate(boundaries): - if b.dtype.base_dtype != x_recomp.dtype.base_dtype: - # We can promote int32 boundaries to int64 without loss of precision. - # This covers the most common case where the user passes in boundaries - # as an array of Python integers. - if (b.dtype.base_dtype == dtypes.int32 and - x_recomp.dtype.base_dtype == dtypes.int64): - b = math_ops.cast(b, x_recomp.dtype.base_dtype) - boundaries[i] = b - else: - raise ValueError( - "Boundaries (%s) must have the same dtype as x (%s)." % - (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) - # TODO(rdipietro): Ensure that boundaries' elements strictly increases. - for v in values[1:]: - if v.dtype.base_dtype != values[0].dtype.base_dtype: - raise ValueError( - "Values must have elements all with the same dtype (%s vs %s)." % - (values[0].dtype.base_dtype, v.dtype.base_dtype)) - pred_fn_pairs = [] - pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0])) - pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1])) - for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): - # Need to bind v here; can do this with lambda v=v: ... - pred = (x_recomp > low) & (x_recomp <= high) - pred_fn_pairs.append((pred, lambda v=v: v)) - - # The default isn't needed here because our conditions are mutually - # exclusive and exhaustive, but tf.case requires it. - default = lambda: values[0] - return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr - - -@tf_export("train.polynomial_decay") + decayed_lr = learning_rate_decay_v2.piecewise_constant(x, boundaries, values, + name=name) + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr + + +@tf_export(v1=["train.polynomial_decay"]) def polynomial_decay(learning_rate, global_step, decay_steps, @@ -299,46 +236,22 @@ def polynomial_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("global_step is required for polynomial_decay.") - with ops.name_scope( - name, "PolynomialDecay", - [learning_rate, global_step, decay_steps, end_learning_rate, power - ]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - end_learning_rate = math_ops.cast(end_learning_rate, dtype) - power = math_ops.cast(power, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - decay_steps_recomp = math_ops.cast(decay_steps, dtype) - if cycle: - # Find the first multiple of decay_steps that is bigger than - # global_step. If global_step is zero set the multiplier to 1 - multiplier = control_flow_ops.cond( - math_ops.equal(global_step_recomp, 0), lambda: 1.0, - lambda: math_ops.ceil(global_step_recomp / decay_steps)) - decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier) - else: - # Make sure that the global_step used is not bigger than decay_steps. - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - - p = math_ops.div(global_step_recomp, decay_steps_recomp) - return math_ops.add( - math_ops.multiply(learning_rate - end_learning_rate, - math_ops.pow(1 - p, power)), - end_learning_rate, - name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr - - -@tf_export("train.natural_exp_decay") + decayed_lr = learning_rate_decay_v2.polynomial_decay( + learning_rate, + global_step, + decay_steps, + end_learning_rate=end_learning_rate, + power=power, + cycle=cycle, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr + + +@tf_export(v1=["train.natural_exp_decay"]) def natural_exp_decay(learning_rate, global_step, decay_steps, @@ -410,32 +323,17 @@ def natural_exp_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("global_step is required for natural_exp_decay.") - with ops.name_scope(name, "NaturalExpDecay", - [learning_rate, global_step, decay_rate]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - decay_rate = math_ops.cast(decay_rate, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - p = global_step_recomp / decay_steps - if staircase: - p = math_ops.floor(p) - exponent = math_ops.exp( - math_ops.multiply(math_ops.negative(decay_rate), p)) - return math_ops.multiply(learning_rate, exponent, name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr - - -@tf_export("train.inverse_time_decay") + decayed_lr = learning_rate_decay_v2.natural_exp_decay( + learning_rate, global_step, decay_steps, decay_rate, staircase=staircase, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr + + +@tf_export(v1=["train.inverse_time_decay"]) def inverse_time_decay(learning_rate, global_step, decay_steps, @@ -507,32 +405,21 @@ def inverse_time_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("global_step is required for inverse_time_decay.") - with ops.name_scope(name, "InverseTimeDecay", - [learning_rate, global_step, decay_rate]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - decay_rate = math_ops.cast(decay_rate, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - p = global_step_recomp / decay_steps - if staircase: - p = math_ops.floor(p) - const = math_ops.cast(constant_op.constant(1), dtype) - denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) - return math_ops.div(learning_rate, denom, name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr - - -@tf_export("train.cosine_decay") + decayed_lr = learning_rate_decay_v2.inverse_time_decay( + learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=staircase, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr + + +@tf_export(v1=["train.cosine_decay"]) def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): """Applies cosine decay to the learning rate. @@ -581,32 +468,16 @@ def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, name=None): the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("cosine decay requires global_step") - with ops.name_scope(name, "CosineDecay", - [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - completed_fraction = global_step_recomp / decay_steps - cosine_decayed = 0.5 * (1.0 + math_ops.cos( - constant_op.constant(math.pi) * completed_fraction)) - - decayed = (1 - alpha) * cosine_decayed + alpha - return math_ops.multiply(learning_rate, decayed) + decayed_lr = learning_rate_decay_v2.cosine_decay( + learning_rate, global_step, decay_steps, alpha=alpha, name=name) - if not context.executing_eagerly(): - decayed_lr = decayed_lr() + if not context.executing_eagerly(): + decayed_lr = decayed_lr() - return decayed_lr + return decayed_lr -@tf_export("train.cosine_decay_restarts") +@tf_export(v1=["train.cosine_decay_restarts"]) def cosine_decay_restarts(learning_rate, global_step, first_decay_steps, @@ -664,57 +535,22 @@ def cosine_decay_restarts(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("cosine decay restarts requires global_step") - with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor( - learning_rate, name="initial_learning_rate") - dtype = learning_rate.dtype - first_decay_steps = math_ops.cast(first_decay_steps, dtype) - alpha = math_ops.cast(alpha, dtype) - t_mul = math_ops.cast(t_mul, dtype) - m_mul = math_ops.cast(m_mul, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - completed_fraction = global_step_recomp / first_decay_steps - - def compute_step(completed_fraction, geometric=False): - """Helper for `cond` operation.""" - if geometric: - i_restart = math_ops.floor( - math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / - math_ops.log(t_mul)) - - sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) - completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart - - else: - i_restart = math_ops.floor(completed_fraction) - completed_fraction -= i_restart + decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( + learning_rate, + global_step, + first_decay_steps, + t_mul=t_mul, + m_mul=m_mul, + alpha=alpha, + name=name) - return i_restart, completed_fraction + if not context.executing_eagerly(): + decayed_lr = decayed_lr() - i_restart, completed_fraction = control_flow_ops.cond( - math_ops.equal(t_mul, 1.0), - lambda: compute_step(completed_fraction, geometric=False), - lambda: compute_step(completed_fraction, geometric=True)) + return decayed_lr - m_fac = m_mul**i_restart - cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( - constant_op.constant(math.pi) * completed_fraction)) - decayed = (1 - alpha) * cosine_decayed + alpha - return math_ops.multiply(learning_rate, decayed, name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr - - -@tf_export("train.linear_cosine_decay") +@tf_export(v1=["train.linear_cosine_decay"]) def linear_cosine_decay(learning_rate, global_step, decay_steps, @@ -781,37 +617,22 @@ def linear_cosine_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("linear cosine decay requires global_step") - with ops.name_scope(name, "LinearCosineDecay", - [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - num_periods = math_ops.cast(num_periods, dtype) - alpha = math_ops.cast(alpha, dtype) - beta = math_ops.cast(beta, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - linear_decayed = (decay_steps - global_step_recomp) / decay_steps - completed_fraction = global_step_recomp / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) - - linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta - return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr - - -@tf_export("train.noisy_linear_cosine_decay") + decayed_lr = learning_rate_decay_v2.linear_cosine_decay( + learning_rate, + global_step, + decay_steps, + num_periods=num_periods, + alpha=alpha, + beta=beta, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr + + +@tf_export(v1=["train.noisy_linear_cosine_decay"]) def noisy_linear_cosine_decay(learning_rate, global_step, decay_steps, @@ -886,42 +707,17 @@ def noisy_linear_cosine_decay(learning_rate, the learning rate value across different invocations of optimizer functions. @end_compatibility """ - if global_step is None: - raise ValueError("noisy linear cosine decay requires global_step") - with ops.name_scope(name, "NoisyLinearCosineDecay", - [learning_rate, global_step]) as name: - learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") - dtype = learning_rate.dtype - decay_steps = math_ops.cast(decay_steps, dtype) - initial_variance = math_ops.cast(initial_variance, dtype) - variance_decay = math_ops.cast(variance_decay, dtype) - num_periods = math_ops.cast(num_periods, dtype) - alpha = math_ops.cast(alpha, dtype) - beta = math_ops.cast(beta, dtype) - - def decayed_lr(): - """Helper to recompute learning rate; most helpful in eager-mode.""" - global_step_recomp = math_ops.cast(global_step, dtype) - global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) - linear_decayed = (decay_steps - global_step_recomp) / decay_steps - variance = initial_variance / ( - math_ops.pow(1.0 + global_step_recomp, variance_decay)) - std = math_ops.sqrt(variance) - noisy_linear_decayed = ( - linear_decayed + random_ops.random_normal( - linear_decayed.shape, stddev=std)) - - completed_fraction = global_step_recomp / decay_steps - fraction = 2.0 * num_periods * completed_fraction - cosine_decayed = 0.5 * ( - 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) - noisy_linear_cosine_decayed = ( - (alpha + noisy_linear_decayed) * cosine_decayed + beta) - - return math_ops.multiply( - learning_rate, noisy_linear_cosine_decayed, name=name) - - if not context.executing_eagerly(): - decayed_lr = decayed_lr() - - return decayed_lr + decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay( + learning_rate, global_step, + decay_steps, + initial_variance=initial_variance, + variance_decay=variance_decay, + num_periods=num_periods, + alpha=alpha, + beta=beta, + name=name) + + if not context.executing_eagerly(): + decayed_lr = decayed_lr() + + return decayed_lr diff --git a/tensorflow/python/training/learning_rate_decay_v2.py b/tensorflow/python/training/learning_rate_decay_v2.py new file mode 100644 index 0000000000..9c5e144be6 --- /dev/null +++ b/tensorflow/python/training/learning_rate_decay_v2.py @@ -0,0 +1,898 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Various learning rate decay functions.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import math + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.util.tf_export import tf_export + + +@tf_export("train.exponential_decay", v1=[]) +def exponential_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies exponential decay to the learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies an exponential decay function + to a provided initial learning rate. It requires a `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns a no-arg function that produces the decayed learning + rate. This can be useful for changing the learning rate value across + different invocations of optimizer functions. + It is computed as: + + ```python + decayed_learning_rate = learning_rate * + decay_rate ^ (global_step / decay_steps) + ``` + + If the argument `staircase` is `True`, then `global_step / decay_steps` is an + integer division and the decayed learning rate follows a staircase function. + + Example: decay every 100000 steps with a base of 0.96: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + starter_learning_rate = 0.1 + learning_rate_fn = tf.train.exponential_decay(starter_learning_rate, + global_step, 100000, 0.96, + staircase=True) + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.train.GradientDescentOptimizer(learning_rate_fn) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. Must not be negative. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Must be positive. See the decay computation above. + decay_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The decay rate. + staircase: Boolean. If `True` decay the learning rate at discrete intervals + name: String. Optional name of the operation. Defaults to + 'ExponentialDecay'. + + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("global_step is required for exponential_decay.") + def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, + staircase, name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope( + name, "ExponentialDecay", + [learning_rate, global_step, decay_steps, decay_rate]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") + dtype = learning_rate.dtype + decay_steps = math_ops.cast(decay_steps, dtype) + decay_rate = math_ops.cast(decay_rate, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + return math_ops.multiply( + learning_rate, math_ops.pow(decay_rate, p), name=name) + + return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, + decay_rate, staircase, name) + + +@tf_export("train.piecewise_constant", v1=[]) +def piecewise_constant(x, boundaries, values, name=None): + """Piecewise constant from boundaries and interval values. + + This function returns a no-arg callable to compute the piecewise constant. + This can be useful for changing the learning rate value across + different invocations of optimizer functions. + + Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5 + for the next 10000 steps, and 0.1 for any additional steps. + + ```python + global_step = tf.Variable(0, trainable=False) + boundaries = [100000, 110000] + values = [1.0, 0.5, 0.1] + learning_rate_fn = tf.train.piecewise_constant(global_step, boundaries, + values) + learning_rate = learning_rate_fn() + + # Later, whenever we perform an optimization step, we increment global_step. + ``` + + Args: + x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`, + `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`. + boundaries: A list of `Tensor`s or `int`s or `float`s with strictly + increasing entries, and with all elements having the same type as `x`. + values: A list of `Tensor`s or `float`s or `int`s that specifies the values + for the intervals defined by `boundaries`. It should have one more element + than `boundaries`, and all elements should have the same type. + name: A string. Optional name of the operation. Defaults to + 'PiecewiseConstant'. + + Returns: + A no-arg function that outputs a 0-D Tensor. The output of the no-arg + function is `values[0]` when `x <= boundaries[0]`, + `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ..., + and values[-1] when `x > boundaries[-1]`. + + Raises: + ValueError: if types of `x` and `boundaries` do not match, or types of all + `values` do not match or + the number of elements in the lists does not match. + """ + if len(boundaries) != len(values) - 1: + raise ValueError( + "The length of boundaries should be 1 less than the length of values") + def decayed_lr(x, boundaries, values, name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope(name, "PiecewiseConstant", + [x, boundaries, values, name]) as name: + boundaries = ops.convert_n_to_tensor(boundaries) + values = ops.convert_n_to_tensor(values) + x_recomp = ops.convert_to_tensor(x) + # Avoid explicit conversion to x's dtype. This could result in faulty + # comparisons, for example if floats are converted to integers. + for i, b in enumerate(boundaries): + if b.dtype.base_dtype != x_recomp.dtype.base_dtype: + # We can promote int32 boundaries to int64 without loss of precision. + # This covers the most common case where the user passes in boundaries + # as an array of Python integers. + if (b.dtype.base_dtype == dtypes.int32 and + x_recomp.dtype.base_dtype == dtypes.int64): + b = math_ops.cast(b, x_recomp.dtype.base_dtype) + boundaries[i] = b + else: + raise ValueError( + "Boundaries (%s) must have the same dtype as x (%s)." % + (b.dtype.base_dtype, x_recomp.dtype.base_dtype)) + # TODO(rdipietro): Ensure that boundaries' elements strictly increases. + for v in values[1:]: + if v.dtype.base_dtype != values[0].dtype.base_dtype: + raise ValueError( + "Values must have elements all with the same dtype (%s vs %s)." % + (values[0].dtype.base_dtype, v.dtype.base_dtype)) + pred_fn_pairs = [] + pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0])) + pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1])) + for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]): + # Need to bind v here; can do this with lambda v=v: ... + pred = (x_recomp > low) & (x_recomp <= high) + pred_fn_pairs.append((pred, lambda v=v: v)) + + # The default isn't needed here because our conditions are mutually + # exclusive and exhaustive, but tf.case requires it. + default = lambda: values[0] + return control_flow_ops.case(pred_fn_pairs, default, exclusive=True) + + return functools.partial(decayed_lr, x, boundaries, values, name) + + +@tf_export("train.polynomial_decay", v1=[]) +def polynomial_decay(learning_rate, + global_step, + decay_steps, + end_learning_rate=0.0001, + power=1.0, + cycle=False, + name=None): + """Applies a polynomial decay to the learning rate. + + It is commonly observed that a monotonically decreasing learning rate, whose + degree of change is carefully chosen, results in a better performing model. + This function applies a polynomial decay function to a provided initial + `learning_rate` to reach an `end_learning_rate` in the given `decay_steps`. + + It requires a `global_step` value to compute the decayed learning rate. You + can just pass a TensorFlow variable that you increment at each training step. + + The function returns a no-arg callable that outputs the decayed learning + rate. This can be useful for changing the learning rate value across + different invocations of optimizer functions. It is computed as: + + ```python + global_step = min(global_step, decay_steps) + decayed_learning_rate = (learning_rate - end_learning_rate) * + (1 - global_step / decay_steps) ^ (power) + + end_learning_rate + + ``` + + If `cycle` is True then a multiple of `decay_steps` is used, the first one + that is bigger than `global_steps`. + + ```python + decay_steps = decay_steps * ceil(global_step / decay_steps) + decayed_learning_rate_fn = (learning_rate - end_learning_rate) * + (1 - global_step / decay_steps) ^ (power) + + end_learning_rate + decayed_learning_rate = decayed_learning_rate_fn() + + ``` + + Example: decay from 0.1 to 0.01 in 10000 steps using sqrt (i.e. power=0.5): + + ```python + ... + global_step = tf.Variable(0, trainable=False) + starter_learning_rate = 0.1 + end_learning_rate = 0.01 + decay_steps = 10000 + learning_rate_fn = tf.train.polynomial_decay(starter_learning_rate, + global_step, decay_steps, + end_learning_rate, + power=0.5) + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.train.GradientDescentOptimizer(learning_rate_fn) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. Must not be negative. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Must be positive. See the decay computation above. + end_learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The minimal end learning rate. + power: A scalar `float32` or `float64` `Tensor` or a + Python number. The power of the polynomial. Defaults to linear, 1.0. + cycle: A boolean, whether or not it should cycle beyond decay_steps. + name: String. Optional name of the operation. Defaults to + 'PolynomialDecay'. + + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("global_step is required for polynomial_decay.") + def decayed_lr(learning_rate, global_step, decay_steps, end_learning_rate, + power, cycle, name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope( + name, "PolynomialDecay", + [learning_rate, global_step, decay_steps, end_learning_rate, power] + ) as name: + learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") + dtype = learning_rate.dtype + end_learning_rate = math_ops.cast(end_learning_rate, dtype) + power = math_ops.cast(power, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + decay_steps_recomp = math_ops.cast(decay_steps, dtype) + if cycle: + # Find the first multiple of decay_steps that is bigger than + # global_step. If global_step is zero set the multiplier to 1 + multiplier = control_flow_ops.cond( + math_ops.equal(global_step_recomp, 0), lambda: 1.0, + lambda: math_ops.ceil(global_step_recomp / decay_steps)) + decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier) + else: + # Make sure that the global_step used is not bigger than decay_steps. + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + + p = math_ops.div(global_step_recomp, decay_steps_recomp) + return math_ops.add( + math_ops.multiply(learning_rate - end_learning_rate, + math_ops.pow(1 - p, power)), + end_learning_rate, + name=name) + + return functools.partial( + decayed_lr, learning_rate, global_step, decay_steps, end_learning_rate, + power, cycle, name) + + +@tf_export("train.natural_exp_decay", v1=[]) +def natural_exp_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies natural exponential decay to the initial learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies an exponential decay function + to a provided initial learning rate. It requires an `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns a no-arg callable that produces the decayed learning + rate. This can be useful for changing the learning rate value across + different invocations of optimizer functions. It is computed as: + + ```python + decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / + decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + decayed_learning_rate = learning_rate * exp(-decay_rate * floor(global_step / + decay_step)) + ``` + + Example: decay exponentially with a base of 0.96: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + learning_rate = 0.1 + decay_steps = 5 + k = 0.5 + learning_rate_fn = tf.train.natural_exp_decay(learning_rate, global_step, + decay_steps, k) + + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.train.GradientDescentOptimizer(learning_rate_fn) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + global_step: A Python number. + Global step to use for the decay computation. Must not be negative. + decay_steps: How often to apply decay. + decay_rate: A Python number. The decay rate. + staircase: Whether to apply decay in a discrete staircase, as opposed to + continuous, fashion. + name: String. Optional name of the operation. Defaults to + 'ExponentialTimeDecay'. + + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("global_step is required for natural_exp_decay.") + def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase, + name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope(name, "NaturalExpDecay", + [learning_rate, global_step, decay_rate]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") + dtype = learning_rate.dtype + decay_steps = math_ops.cast(decay_steps, dtype) + decay_rate = math_ops.cast(decay_rate, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + exponent = math_ops.exp( + math_ops.multiply(math_ops.negative(decay_rate), p)) + return math_ops.multiply(learning_rate, exponent, name=name) + + return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, + decay_rate, staircase, name) + + +@tf_export("train.inverse_time_decay", v1=[]) +def inverse_time_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False, + name=None): + """Applies inverse time decay to the initial learning rate. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies an inverse decay function + to a provided initial learning rate. It requires an `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns a no-arg callable that produces the decayed learning + rate. This can be useful for changing the learning rate value across + different invocations of optimizer functions. It is computed as: + + ```python + decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / + decay_step) + ``` + + or, if `staircase` is `True`, as: + + ```python + decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / + decay_step)) + ``` + + Example: decay 1/t with a rate of 0.5: + + ```python + ... + global_step = tf.Variable(0, trainable=False) + learning_rate = 0.1 + decay_steps = 1.0 + decay_rate = 0.5 + learning_rate_fn = tf.train.inverse_time_decay(learning_rate, global_step, + decay_steps, decay_rate) + + # Passing global_step to minimize() will increment it at each step. + learning_step = ( + tf.train.GradientDescentOptimizer(learning_rate_fn) + .minimize(...my loss..., global_step=global_step) + ) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` `Tensor` or a + Python number. The initial learning rate. + global_step: A Python number. + Global step to use for the decay computation. Must not be negative. + decay_steps: How often to apply decay. + decay_rate: A Python number. The decay rate. + staircase: Whether to apply decay in a discrete staircase, as opposed to + continuous, fashion. + name: String. Optional name of the operation. Defaults to + 'InverseTimeDecay'. + + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("global_step is required for inverse_time_decay.") + def decayed_lr(learning_rate, global_step, decay_steps, decay_rate, staircase, + name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope(name, "InverseTimeDecay", + [learning_rate, global_step, decay_rate]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") + dtype = learning_rate.dtype + decay_steps = math_ops.cast(decay_steps, dtype) + decay_rate = math_ops.cast(decay_rate, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + p = global_step_recomp / decay_steps + if staircase: + p = math_ops.floor(p) + const = math_ops.cast(constant_op.constant(1), dtype) + denom = math_ops.add(const, math_ops.multiply(decay_rate, p)) + return math_ops.div(learning_rate, denom, name=name) + + return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, + decay_rate, staircase, name) + + +@tf_export("train.cosine_decay", v1=[]) +def cosine_decay(learning_rate, global_step, decay_steps, alpha=0.0, + name=None): + """Applies cosine decay to the learning rate. + + See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a cosine decay function + to a provided initial learning rate. It requires a `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns a no-arg callable that produces the decayed learning + rate. This can be useful for changing the learning rate value across + different invocations of optimizer functions. It is computed as: + + ```python + global_step = min(global_step, decay_steps) + cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps)) + decayed = (1 - alpha) * cosine_decay + alpha + decayed_learning_rate = learning_rate * decayed + ``` + + Example usage: + ```python + decay_steps = 1000 + lr_decayed_fn = tf.train.cosine_decay(learning_rate, global_step, decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + alpha: A scalar `float32` or `float64` Tensor or a Python number. + Minimum learning rate value as a fraction of learning_rate. + name: String. Optional name of the operation. Defaults to 'CosineDecay'. + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("cosine decay requires global_step") + def decayed_lr(learning_rate, global_step, decay_steps, alpha, name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope(name, "CosineDecay", + [learning_rate, global_step]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") + dtype = learning_rate.dtype + decay_steps = math_ops.cast(decay_steps, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + completed_fraction = global_step_recomp / decay_steps + cosine_decayed = 0.5 * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + + decayed = (1 - alpha) * cosine_decayed + alpha + return math_ops.multiply(learning_rate, decayed) + + return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, + alpha, name) + + +@tf_export("train.cosine_decay_restarts", v1=[]) +def cosine_decay_restarts(learning_rate, + global_step, + first_decay_steps, + t_mul=2.0, + m_mul=1.0, + alpha=0.0, + name=None): + """Applies cosine decay with restarts to the learning rate. + + See [Loshchilov & Hutter, ICLR2016], SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a cosine decay function with + restarts to a provided initial learning rate. It requires a `global_step` + value to compute the decayed learning rate. You can just pass a TensorFlow + variable that you increment at each training step. + + The function returns a no-arg callable that produces the decayed learning + rate while taking into account possible warm restarts. This can be useful for + changing the learning rate value across different invocations of optimizer + functions. + + The learning rate multiplier first decays + from 1 to `alpha` for `first_decay_steps` steps. Then, a warm + restart is performed. Each new warm restart runs for `t_mul` times more steps + and with `m_mul` times smaller initial learning rate. + + Example usage: + ```python + first_decay_steps = 1000 + lr_decayed_fn = tf.train.cosine_decay_restarts(learning_rate, global_step, + first_decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. + first_decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + t_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Used to derive the number of iterations in the i-th period + m_mul: A scalar `float32` or `float64` `Tensor` or a Python number. + Used to derive the initial learning rate of the i-th period: + alpha: A scalar `float32` or `float64` Tensor or a Python number. + Minimum learning rate value as a fraction of the learning_rate. + name: String. Optional name of the operation. Defaults to 'SGDRDecay'. + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("cosine decay restarts requires global_step") + def decayed_lr(learning_rate, global_step, first_decay_steps, t_mul, m_mul, + alpha, name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope(name, "SGDRDecay", [learning_rate, global_step] + ) as name: + learning_rate = ops.convert_to_tensor( + learning_rate, name="initial_learning_rate") + dtype = learning_rate.dtype + first_decay_steps = math_ops.cast(first_decay_steps, dtype) + alpha = math_ops.cast(alpha, dtype) + t_mul = math_ops.cast(t_mul, dtype) + m_mul = math_ops.cast(m_mul, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + completed_fraction = global_step_recomp / first_decay_steps + + def compute_step(completed_fraction, geometric=False): + """Helper for `cond` operation.""" + if geometric: + i_restart = math_ops.floor( + math_ops.log(1.0 - completed_fraction * (1.0 - t_mul)) / + math_ops.log(t_mul)) + + sum_r = (1.0 - t_mul**i_restart) / (1.0 - t_mul) + completed_fraction = (completed_fraction - sum_r) / t_mul**i_restart + + else: + i_restart = math_ops.floor(completed_fraction) + completed_fraction -= i_restart + + return i_restart, completed_fraction + + i_restart, completed_fraction = control_flow_ops.cond( + math_ops.equal(t_mul, 1.0), + lambda: compute_step(completed_fraction, geometric=False), + lambda: compute_step(completed_fraction, geometric=True)) + + m_fac = m_mul**i_restart + cosine_decayed = 0.5 * m_fac * (1.0 + math_ops.cos( + constant_op.constant(math.pi) * completed_fraction)) + decayed = (1 - alpha) * cosine_decayed + alpha + + return math_ops.multiply(learning_rate, decayed, name=name) + + return functools.partial(decayed_lr, learning_rate, global_step, + first_decay_steps, t_mul, m_mul, alpha, name) + + +@tf_export("train.linear_cosine_decay", v1=[]) +def linear_cosine_decay(learning_rate, + global_step, + decay_steps, + num_periods=0.5, + alpha=0.0, + beta=0.001, + name=None): + """Applies linear cosine decay to the learning rate. + + See [Bello et al., ICML2017] Neural Optimizer Search with RL. + https://arxiv.org/abs/1709.07417 + + For the idea of warm starts here controlled by `num_periods`, + see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + Note that linear cosine decay is more aggressive than cosine decay and + larger initial learning rates can typically be used. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a linear cosine decay function + to a provided initial learning rate. It requires a `global_step` value to + compute the decayed learning rate. You can just pass a TensorFlow variable + that you increment at each training step. + + The function returns a no-arg callable that produces the decayed learning + rate. This can be useful for changing the learning rate value across + different invocations of optimizer functions. It is computed as: + + ```python + global_step = min(global_step, decay_steps) + linear_decay = (decay_steps - global_step) / decay_steps) + cosine_decay = 0.5 * ( + 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) + decayed = (alpha + linear_decay) * cosine_decay + beta + decayed_learning_rate = learning_rate * decayed + ``` + + Example usage: + ```python + decay_steps = 1000 + lr_decayed_fn = tf.train.linear_cosine_decay(learning_rate, global_step, + decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + num_periods: Number of periods in the cosine part of the decay. + See computation above. + alpha: See computation above. + beta: See computation above. + name: String. Optional name of the operation. Defaults to + 'LinearCosineDecay'. + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("linear cosine decay requires global_step") + def decayed_lr(learning_rate, global_step, decay_steps, num_periods, alpha, + beta, name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope(name, "LinearCosineDecay", + [learning_rate, global_step]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") + dtype = learning_rate.dtype + decay_steps = math_ops.cast(decay_steps, dtype) + num_periods = math_ops.cast(num_periods, dtype) + alpha = math_ops.cast(alpha, dtype) + beta = math_ops.cast(beta, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + + linear_cosine_decayed = (alpha + linear_decayed) * cosine_decayed + beta + return math_ops.multiply(learning_rate, linear_cosine_decayed, name=name) + + return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, + num_periods, alpha, beta, name) + + +@tf_export("train.noisy_linear_cosine_decay", v1=[]) +def noisy_linear_cosine_decay(learning_rate, + global_step, + decay_steps, + initial_variance=1.0, + variance_decay=0.55, + num_periods=0.5, + alpha=0.0, + beta=0.001, + name=None): + """Applies noisy linear cosine decay to the learning rate. + + See [Bello et al., ICML2017] Neural Optimizer Search with RL. + https://arxiv.org/abs/1709.07417 + + For the idea of warm starts here controlled by `num_periods`, + see [Loshchilov & Hutter, ICLR2016] SGDR: Stochastic Gradient Descent + with Warm Restarts. https://arxiv.org/abs/1608.03983 + + Note that linear cosine decay is more aggressive than cosine decay and + larger initial learning rates can typically be used. + + When training a model, it is often recommended to lower the learning rate as + the training progresses. This function applies a noisy linear + cosine decay function to a provided initial learning rate. + It requires a `global_step` value to compute the decayed learning rate. + You can just pass a TensorFlow variable that you increment at each + training step. + + The function returns a no-arg callable that produces the decayed learning + rate. This can be useful for changing the learning rate value across + different invocations of optimizer functions. It is computed as: + + ```python + global_step = min(global_step, decay_steps) + linear_decay = (decay_steps - global_step) / decay_steps) + cosine_decay = 0.5 * ( + 1 + cos(pi * 2 * num_periods * global_step / decay_steps)) + decayed = (alpha + linear_decay + eps_t) * cosine_decay + beta + decayed_learning_rate = learning_rate * decayed + ``` + where eps_t is 0-centered gaussian noise with variance + initial_variance / (1 + global_step) ** variance_decay + + Example usage: + ```python + decay_steps = 1000 + lr_decayed_fn = tf.train.noisy_linear_cosine_decay(learning_rate, global_step, + decay_steps) + ``` + + Args: + learning_rate: A scalar `float32` or `float64` Tensor or a Python number. + The initial learning rate. + global_step: A scalar `int32` or `int64` `Tensor` or a Python number. + Global step to use for the decay computation. + decay_steps: A scalar `int32` or `int64` `Tensor` or a Python number. + Number of steps to decay over. + initial_variance: initial variance for the noise. See computation above. + variance_decay: decay for the noise's variance. See computation above. + num_periods: Number of periods in the cosine part of the decay. + See computation above. + alpha: See computation above. + beta: See computation above. + name: String. Optional name of the operation. Defaults to + 'NoisyLinearCosineDecay'. + Returns: + A no-arg function that outputs the decayed learning rate, a scalar `Tensor` + of the same type as `learning_rate`. + Raises: + ValueError: if `global_step` is not supplied. + """ + if global_step is None: + raise ValueError("noisy linear cosine decay requires global_step") + def decayed_lr(learning_rate, global_step, decay_steps, initial_variance, + variance_decay, num_periods, alpha, beta, name): + """Helper to recompute learning rate; most helpful in eager-mode.""" + with ops.name_scope(name, "NoisyLinearCosineDecay", + [learning_rate, global_step]) as name: + learning_rate = ops.convert_to_tensor(learning_rate, name="learning_rate") + dtype = learning_rate.dtype + decay_steps = math_ops.cast(decay_steps, dtype) + initial_variance = math_ops.cast(initial_variance, dtype) + variance_decay = math_ops.cast(variance_decay, dtype) + num_periods = math_ops.cast(num_periods, dtype) + alpha = math_ops.cast(alpha, dtype) + beta = math_ops.cast(beta, dtype) + + global_step_recomp = math_ops.cast(global_step, dtype) + global_step_recomp = math_ops.minimum(global_step_recomp, decay_steps) + linear_decayed = (decay_steps - global_step_recomp) / decay_steps + variance = initial_variance / ( + math_ops.pow(1.0 + global_step_recomp, variance_decay)) + std = math_ops.sqrt(variance) + noisy_linear_decayed = ( + linear_decayed + random_ops.random_normal( + linear_decayed.shape, stddev=std)) + + completed_fraction = global_step_recomp / decay_steps + fraction = 2.0 * num_periods * completed_fraction + cosine_decayed = 0.5 * ( + 1.0 + math_ops.cos(constant_op.constant(math.pi) * fraction)) + noisy_linear_cosine_decayed = ( + (alpha + noisy_linear_decayed) * cosine_decayed + beta) + + return math_ops.multiply( + learning_rate, noisy_linear_cosine_decayed, name=name) + + return functools.partial(decayed_lr, learning_rate, global_step, decay_steps, + initial_variance, variance_decay, num_periods, alpha, + beta, name) diff --git a/tensorflow/python/training/learning_rate_decay_v2_test.py b/tensorflow/python/training/learning_rate_decay_v2_test.py new file mode 100644 index 0000000000..0f2d60dafc --- /dev/null +++ b/tensorflow/python/training/learning_rate_decay_v2_test.py @@ -0,0 +1,497 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functional test for learning rate decay.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util +# Import resource_variable_ops for the variables-to-tensor implicit conversion. +from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest +from tensorflow.python.training import learning_rate_decay_v2 + + +class LRDecayTestV2(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testContinuous(self): + self.evaluate(variables.global_variables_initializer()) + step = 5 + decayed_lr = learning_rate_decay_v2.exponential_decay(0.05, step, 10, 0.96) + expected = .05 * 0.96**(5.0 / 10.0) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testStaircase(self): + if context.executing_eagerly(): + step = resource_variable_ops.ResourceVariable(0) + self.evaluate(variables.global_variables_initializer()) + decayed_lr = learning_rate_decay_v2.exponential_decay( + .1, step, 3, 0.96, staircase=True) + + # No change to learning rate due to staircase + expected = .1 + self.evaluate(step.assign(1)) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + expected = .1 + self.evaluate(step.assign(2)) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + # Decayed learning rate + expected = .1 * 0.96 ** (100 // 3) + self.evaluate(step.assign(100)) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + def testVariables(self): + with self.test_session(): + step = variables.Variable(1) + assign_1 = step.assign(1) + assign_2 = step.assign(2) + assign_100 = step.assign(100) + decayed_lr = learning_rate_decay_v2.exponential_decay(.1, step, 3, 0.96, + staircase=True) + variables.global_variables_initializer().run() + # No change to learning rate + assign_1.op.run() + self.assertAllClose(decayed_lr().eval(), .1, 1e-6) + assign_2.op.run() + self.assertAllClose(decayed_lr().eval(), .1, 1e-6) + # Decayed learning rate + assign_100.op.run() + expected = .1 * 0.96 ** (100 // 3) + self.assertAllClose(decayed_lr().eval(), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testPiecewiseConstant(self): + x = resource_variable_ops.ResourceVariable(-999) + decayed_lr = learning_rate_decay_v2.piecewise_constant( + x, [100, 110, 120], [1.0, 0.1, 0.01, 0.001]) + + self.evaluate(variables.global_variables_initializer()) + + self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6) + self.evaluate(x.assign(100)) + self.assertAllClose(self.evaluate(decayed_lr()), 1.0, 1e-6) + self.evaluate(x.assign(105)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6) + self.evaluate(x.assign(110)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.1, 1e-6) + self.evaluate(x.assign(120)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.01, 1e-6) + self.evaluate(x.assign(999)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.001, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testPiecewiseConstantEdgeCases(self): + x_int = resource_variable_ops.ResourceVariable( + 0, dtype=variables.dtypes.int32) + boundaries, values = [-1.0, 1.0], [1, 2, 3] + with self.assertRaises(ValueError): + decayed_lr = learning_rate_decay_v2.piecewise_constant( + x_int, boundaries, values) + decayed_lr() + + x = resource_variable_ops.ResourceVariable(0.0) + boundaries, values = [-1.0, 1.0], [1.0, 2, 3] + with self.assertRaises(ValueError): + decayed_lr = learning_rate_decay_v2.piecewise_constant( + x, boundaries, values)() + decayed_lr() + + # Test that ref types are valid. + if not context.executing_eagerly(): + x = variables.Variable(0.0) + x_ref = x.op.outputs[0] # float32_ref tensor should be accepted + boundaries, values = [1.0, 2.0], [1, 2, 3] + learning_rate_decay_v2.piecewise_constant(x_ref, boundaries, values) + + # Test casting boundaries from int32 to int64. + x_int64 = resource_variable_ops.ResourceVariable( + 0, dtype=variables.dtypes.int64) + boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7] + decayed_lr = learning_rate_decay_v2.piecewise_constant( + x_int64, boundaries, values) + + self.evaluate(variables.global_variables_initializer()) + self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6) + self.evaluate(x_int64.assign(1)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.4, 1e-6) + self.evaluate(x_int64.assign(2)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.5, 1e-6) + self.evaluate(x_int64.assign(3)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.6, 1e-6) + self.evaluate(x_int64.assign(4)) + self.assertAllClose(self.evaluate(decayed_lr()), 0.7, 1e-6) + + +class LinearDecayTestV2(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testHalfWay(self): + step = 5 + lr = 0.05 + end_lr = 0.0 + decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) + expected = lr * 0.5 + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testEnd(self): + step = 10 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testHalfWayWithEnd(self): + step = 5 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) + expected = (lr + end_lr) * 0.5 + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEnd(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay_v2.polynomial_decay(lr, step, 10, end_lr) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEndWithCycle(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + decayed_lr = learning_rate_decay_v2.polynomial_decay( + lr, step, 10, end_lr, cycle=True) + expected = (lr - end_lr) * 0.25 + end_lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + +class SqrtDecayTestV2(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testHalfWay(self): + step = 5 + lr = 0.05 + end_lr = 0.0 + power = 0.5 + decayed_lr = learning_rate_decay_v2.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = lr * 0.5**power + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testEnd(self): + step = 10 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay_v2.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testHalfWayWithEnd(self): + step = 5 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay_v2.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = (lr - end_lr) * 0.5**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEnd(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay_v2.polynomial_decay( + lr, step, 10, end_lr, power=power) + expected = end_lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testBeyondEndWithCycle(self): + step = 15 + lr = 0.05 + end_lr = 0.001 + power = 0.5 + decayed_lr = learning_rate_decay_v2.polynomial_decay( + lr, step, 10, end_lr, power=power, cycle=True) + expected = (lr - end_lr) * 0.25**power + end_lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + +class PolynomialDecayTestV2(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testBeginWithCycle(self): + lr = 0.001 + decay_steps = 10 + step = 0 + decayed_lr = learning_rate_decay_v2.polynomial_decay( + lr, step, decay_steps, cycle=True) + expected = lr + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + +class ExponentialDecayTestV2(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testDecay(self): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay_v2.natural_exp_decay(initial_lr, step, k, + decay_rate) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + @test_util.run_in_graph_and_eager_modes + def testStaircase(self): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay_v2.natural_exp_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr * math.exp(-decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + +class InverseDecayTestV2(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testDecay(self): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay_v2.inverse_time_decay(initial_lr, step, k, + decay_rate) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + i / k * decay_rate) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + @test_util.run_in_graph_and_eager_modes + def testStaircase(self): + initial_lr = 0.1 + k = 10 + decay_rate = 0.96 + step = resource_variable_ops.ResourceVariable(0) + decayed_lr = learning_rate_decay_v2.inverse_time_decay( + initial_lr, step, k, decay_rate, staircase=True) + + self.evaluate(variables.global_variables_initializer()) + for i in range(k + 1): + expected = initial_lr / (1 + decay_rate * (i // k)) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + self.evaluate(step.assign_add(1)) + + +class CosineDecayTestV2(test_util.TensorFlowTestCase): + + def np_cosine_decay(self, step, decay_steps, alpha=0.0): + step = min(step, decay_steps) + completed_fraction = step / decay_steps + decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) + return (1.0 - alpha) * decay + alpha + + @test_util.run_in_graph_and_eager_modes + def testDecay(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step, + num_training_steps) + expected = self.np_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testAlpha(self): + num_training_steps = 1000 + initial_lr = 1.0 + alpha = 0.1 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.cosine_decay(initial_lr, step, + num_training_steps, + alpha) + expected = self.np_cosine_decay(step, num_training_steps, alpha) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + +class CosineDecayRestartsTestV2(test_util.TensorFlowTestCase): + + def np_cosine_decay_restarts(self, step, decay_steps, t_mul=2.0, m_mul=1.0, + alpha=0.0): + fac = 1.0 + while step >= decay_steps: + step -= decay_steps + decay_steps *= t_mul + fac *= m_mul + + completed_fraction = step / decay_steps + decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction)) + return (1.0 - alpha) * decay + alpha + + @test_util.run_in_graph_and_eager_modes + def testDecay(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( + initial_lr, step, num_training_steps) + expected = self.np_cosine_decay_restarts(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testAlpha(self): + num_training_steps = 1000 + initial_lr = 1.0 + alpha = 0.1 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( + initial_lr, step, num_training_steps, alpha=alpha) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, alpha=alpha) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testMMul(self): + num_training_steps = 1000 + initial_lr = 1.0 + m_mul = 0.9 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( + initial_lr, step, num_training_steps, m_mul=m_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, m_mul=m_mul) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testTMul(self): + num_training_steps = 1000 + initial_lr = 1.0 + t_mul = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.cosine_decay_restarts( + initial_lr, step, num_training_steps, t_mul=t_mul) + expected = self.np_cosine_decay_restarts( + step, num_training_steps, t_mul=t_mul) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + +class LinearCosineDecayTestV2(test_util.TensorFlowTestCase): + + def np_linear_cosine_decay(self, + step, + decay_steps, + alpha=0.0, + beta=0.001, + num_periods=0.5): + step = min(step, decay_steps) + linear_decayed = float(decay_steps - step) / decay_steps + fraction = 2.0 * num_periods * step / float(decay_steps) + cosine_decayed = 0.5 * (1.0 + math.cos(math.pi * fraction)) + return (alpha + linear_decayed) * cosine_decayed + beta + + @test_util.run_in_graph_and_eager_modes + def testDefaultDecay(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.linear_cosine_decay( + initial_lr, step, num_training_steps) + expected = self.np_linear_cosine_decay(step, num_training_steps) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + @test_util.run_in_graph_and_eager_modes + def testNonDefaultDecay(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + decayed_lr = learning_rate_decay_v2.linear_cosine_decay( + initial_lr, + step, + num_training_steps, + alpha=0.1, + beta=1e-4, + num_periods=5) + expected = self.np_linear_cosine_decay( + step, num_training_steps, alpha=0.1, beta=1e-4, num_periods=5) + self.assertAllClose(self.evaluate(decayed_lr()), expected, 1e-6) + + +class NoisyLinearCosineDecayTestV2(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testDefaultNoisyLinearCosine(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + # No numerical check because of noise + decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay( + initial_lr, step, num_training_steps) + # Cannot be deterministically tested + self.evaluate(decayed_lr()) + + @test_util.run_in_graph_and_eager_modes + def testNonDefaultNoisyLinearCosine(self): + num_training_steps = 1000 + initial_lr = 1.0 + for step in range(0, 1500, 250): + # No numerical check because of noise + decayed_lr = learning_rate_decay_v2.noisy_linear_cosine_decay( + initial_lr, + step, + num_training_steps, + initial_variance=0.5, + variance_decay=0.1, + alpha=0.1, + beta=1e-4, + num_periods=5) + # Cannot be deterministically tested + self.evaluate(decayed_lr()) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py index 9702430a12..38216ce9b1 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import argparse +import functools from tensorflow.tools.compatibility import ast_edits from tensorflow.tools.compatibility import renames_v2 @@ -45,6 +46,29 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec): # Specially handled functions. self.function_handle = {} + for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant", + "tf.train.polynomial_decay", "tf.train.natural_exp_decay", + "tf.train.inverse_time_decay", "tf.train.cosine_decay", + "tf.train.cosine_decay_restarts", + "tf.train.linear_cosine_decay", + "tf.train.noisy_linear_cosine_decay"]: + self.function_handle[decay] = functools.partial( + self._learning_rate_decay_handler, decay_name=decay) + + @staticmethod + def _learning_rate_decay_handler(file_edit_recorder, node, decay_name): + comment = ("ERROR: %s has been changed to return a callable instead of a " + "tensor when graph building, but its functionality remains " + "unchanged during eager execution (returns a callable like " + "before). The converter cannot detect and fix this reliably, so " + "you need to inspect this usage manually.\n") % decay_name + file_edit_recorder.add( + comment, + node.lineno, + node.col_offset, + decay_name, + decay_name, + error="%s requires manual check." % decay_name) if __name__ == "__main__": diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py index 57ac04de06..3886c1e8b9 100644 --- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py +++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py @@ -63,6 +63,19 @@ class TestUpgrade(test_util.TensorFlowTestCase): _, unused_report, unused_errors, new_text = self._upgrade(text) self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n") + def testLearningRateDecay(self): + for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant", + "tf.train.polynomial_decay", "tf.train.natural_exp_decay", + "tf.train.inverse_time_decay", "tf.train.cosine_decay", + "tf.train.cosine_decay_restarts", + "tf.train.linear_cosine_decay", + "tf.train.noisy_linear_cosine_decay"]: + + text = "%s(a, b)\n" % decay + _, unused_report, errors, new_text = self._upgrade(text) + self.assertEqual(text, new_text) + self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay]) + class TestUpgradeFiles(test_util.TensorFlowTestCase): -- GitLab From 6bd9f8fa0c17c55fc0c11ba0d9281cab1688b115 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Wed, 5 Sep 2018 17:17:23 -0700 Subject: [PATCH 166/540] Rollforward of cl/211656888 after fixing failing unit test. *** Original change description *** Add HloSchedule class representing a sequential order of an HloModule. Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector. This CL replaces this with a proper class which results in better encap... *** PiperOrigin-RevId: 211726890 --- tensorflow/compiler/xla/service/BUILD | 48 +++ .../compiler/xla/service/buffer_assignment.cc | 28 +- .../xla/service/buffer_assignment_test.cc | 98 ++--- .../xla/service/buffer_liveness_test.cc | 42 +-- .../compiler/xla/service/cpu/cpu_compiler.cc | 56 ++- .../compiler/xla/service/cpu/ir_emitter.cc | 2 +- .../compiler/xla/service/cpu/ir_emitter.h | 2 +- tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/gpu_hlo_schedule.cc | 6 +- .../xla/service/gpu/gpu_hlo_schedule.h | 4 +- .../compiler/xla/service/heap_simulator.cc | 43 +-- .../compiler/xla/service/heap_simulator.h | 48 ++- .../xla/service/heap_simulator_test.cc | 36 +- .../xla/service/hlo_alias_analysis_test.cc | 16 +- .../xla/service/hlo_dataflow_analysis_test.cc | 29 +- .../compiler/xla/service/hlo_ordering.cc | 86 ++--- .../compiler/xla/service/hlo_ordering.h | 22 +- .../compiler/xla/service/hlo_ordering_test.cc | 101 ++++++ .../xla/service/hlo_rematerialization.cc | 87 ++--- .../xla/service/hlo_rematerialization.h | 19 +- .../xla/service/hlo_rematerialization_test.cc | 46 +-- .../compiler/xla/service/hlo_schedule.cc | 291 +++++++++++++++ .../compiler/xla/service/hlo_schedule.h | 151 ++++++++ .../compiler/xla/service/hlo_schedule_test.cc | 341 +++++++++++++++++ .../compiler/xla/service/hlo_scheduling.cc | 230 ++---------- .../compiler/xla/service/hlo_scheduling.h | 54 +-- .../xla/service/hlo_scheduling_test.cc | 343 +++--------------- 27 files changed, 1325 insertions(+), 905 deletions(-) create mode 100644 tensorflow/compiler/xla/service/hlo_schedule.cc create mode 100644 tensorflow/compiler/xla/service/hlo_schedule.h create mode 100644 tensorflow/compiler/xla/service/hlo_schedule_test.cc diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 64141ed191..ab86dce510 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -989,6 +989,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1036,6 +1037,7 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1049,6 +1051,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1062,6 +1065,7 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", + ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1082,6 +1086,7 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1089,6 +1094,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1102,6 +1108,7 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", + ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1125,6 +1132,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", ], ) @@ -1169,6 +1177,43 @@ cc_library( ], ) +cc_library( + name = "hlo_schedule", + srcs = ["hlo_schedule.cc"], + hdrs = ["hlo_schedule.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib_internal", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +tf_cc_test( + name = "hlo_schedule_test", + srcs = ["hlo_schedule_test.cc"], + deps = [ + ":heap_simulator", + ":hlo", + ":hlo_dce", + ":hlo_ordering", + ":hlo_parser", + ":hlo_schedule", + ":hlo_scheduling", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", + ], +) + cc_library( name = "hlo_scheduling", srcs = ["hlo_scheduling.cc"], @@ -1177,6 +1222,7 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1205,6 +1251,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", + "@com_google_absl//absl/algorithm:container", ], ) @@ -2366,6 +2413,7 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", + ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8b8c6bfd26..0f0af57626 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -617,18 +617,24 @@ Status BufferAssignment::ComputeSummaryStats() { } // Only compute total fragmentation if all computations have schedules. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_); + bool schedule_complete = true; for (const auto& computation : module_->computations()) { - const std::vector* sequence = - liveness_->hlo_ordering().SequentialOrder(*computation); - if (sequence != nullptr) { - module_sequence.emplace(computation, *sequence); + if (!computation->IsFusionComputation()) { + const std::vector* sequence = + liveness_->hlo_ordering().SequentialOrder(*computation); + if (sequence == nullptr) { + schedule_complete = false; + } else { + schedule.set_sequence(computation, *sequence); + } } } - if (module_sequence.size() == module_->computation_count()) { + if (schedule_complete) { + TF_RETURN_IF_ERROR(schedule.Verify()); TF_ASSIGN_OR_RETURN( const int64 min_size, - HeapSimulator::MinimumMemoryForModule(module_sequence, buffer_size_)); + HeapSimulator::MinimumMemoryForModule(schedule, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } @@ -1064,7 +1070,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // since buffers for kCall, kWhile, and kConditional sub-computations are // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(&assignment->module()); FlatSet all_buffers_to_assign; for (const auto& pair : buffers_to_assign_sequentially) { const HloComputation* computation = pair.first; @@ -1072,7 +1078,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const std::vector* instruction_sequence = hlo_ordering.SequentialOrder(*computation); CHECK(instruction_sequence != nullptr) << computation->name(); - module_sequence[computation] = *instruction_sequence; + schedule.set_sequence(computation, *instruction_sequence); all_buffers_to_assign.insert(buffers_to_assign.begin(), buffers_to_assign.end()); } @@ -1090,7 +1096,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique( absl::make_unique(alignment)), - assignment->module(), module_sequence, + assignment->module(), schedule, assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, @@ -1121,7 +1127,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Run( absl::make_unique( absl::make_unique(alignment)), - *computation, *instruction_sequence, + *computation, HloInstructionSequence(*instruction_sequence), assignment->points_to_analysis(), assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 56bd67fb55..5a231c173d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/macros.h" namespace xla { @@ -120,14 +122,10 @@ class BufferAssignmentTest : public HloVerifiedTestBase { HloModule* module, absl::Span instruction_sequence, int64 alignment = 1) { - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[module->entry_computation()] = - std::vector(instruction_sequence.begin(), - instruction_sequence.end()); + HloSchedule schedule(module); + schedule.set_sequence(module->entry_computation(), instruction_sequence); return BufferAssigner::Run( - module, - absl::make_unique(module, - module_sequence), + module, absl::make_unique(schedule), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -1785,11 +1783,10 @@ class WhileBufferAssignmentTest : public HloVerifiedTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, - absl::make_unique(module, sequence), + module, absl::make_unique(schedule), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }, /*allow_input_output_aliasing=*/false, @@ -2096,17 +2093,25 @@ TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { // Create a sequential order among all the instructions in the entry // computation, since the issue this test stresses depends on the order the // nodes are traversed during BufferAssignment. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[module->entry_computation()] = { - token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + schedule.set_sequence( + module->entry_computation(), + {token, infeed, infeed_data, while0, while1, zero, add, while2, tuple}); + TF_ASSERT_OK(schedule.Verify()); + TF_ASSERT_OK_AND_ASSIGN( auto assignment, - BufferAssigner::Run( - module, absl::make_unique(module, sequence), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // The result tuple elements must be assigned with different buffers. TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); @@ -2263,29 +2268,6 @@ ENTRY Main { GetAllocation(*buffers, param0, {1, 1})); } -static bool IsPostOrderTraversal( - const std::vector& sequence) { - tensorflow::gtl::FlatSet seen_so_far; - auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { - return seen_so_far.count(instruction) == 0; - }; - - for (auto instruction : sequence) { - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), has_not_been_seen_yet) || - std::any_of(instruction->control_predecessors().begin(), - instruction->control_predecessors().end(), - has_not_been_seen_yet)) { - return false; // Not a post order. - } - if (!seen_so_far.insert(instruction).second) { - return false; // Not a "traversal". - } - } - - return true; -} - TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); @@ -2340,27 +2322,27 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module); - auto sequence = - ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); + HloSchedule schedule = + ScheduleModule(*module, ByteSizeOf).ConsumeValueOrDie(); - // To trigger b/38494731, we want a specific Hlo sequence for the + // To trigger b/38494731, we want a specific Hlo schedule for the // root computation, so we overwrite that entry with a manually // crafted sequence. - sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + schedule.set_sequence(module->entry_computation(), + {input1, weights1, one, output1, while1->operand(0), + while1, input0, weights0, zero, output0, + while0->operand(0), while0, gte0, gte1, root_add}); - // If this ASSERT_TRUE fails, we constructed a bogus sequence above - // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + // If this ASSERT fails, we constructed a bogus sequence above and this test + // itself is buggy. + TF_ASSERT_OK(schedule.Verify()); auto assignment = - BufferAssigner::Run( - module, absl::make_unique(module, sequence), - ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true) + BufferAssigner::Run(module, + absl::make_unique(schedule), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 26e26e316d..414bfe7999 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -166,12 +167,12 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { auto module = CreateNewModule(); HloComputation* entry = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -291,13 +292,12 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, negate, exp, add}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, {param, negate, exp, add}); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -339,14 +339,14 @@ TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build(add)); - SequentialHloOrdering::HloModuleSequence module_sequence; - std::vector order = {param, add, recv, - recv_done, send, send_done}; - module_sequence.emplace(computation, order); - auto liveness = BufferLiveness::Run(module.get(), - absl::make_unique( - module.get(), module_sequence)) - .ConsumeValueOrDie(); + HloSchedule schedule(module.get()); + schedule.set_sequence(computation, + {param, add, token, recv, recv_done, send, send_done}); + TF_ASSERT_OK(schedule.Verify()); + auto liveness = + BufferLiveness::Run(module.get(), + absl::make_unique(schedule)) + .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); // Check the root instruction (add) buffer interferes with the recv buffer. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 796f36510e..e7b6075994 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -584,16 +584,14 @@ StatusOr> CpuCompiler::RunBackend( // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction(), DFSMemoryScheduler)); // Run buffer allocation on the HLO graph. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run(module.get(), - absl::make_unique( - module.get(), module_sequence), + absl::make_unique(schedule), BufferSizeBytesFunction(), memory_alignment, /*allow_input_output_aliasing=*/false, /*allocate_buffers_for_constants=*/true)); @@ -627,9 +625,10 @@ StatusOr> CpuCompiler::RunBackend( } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } string function_name_prefix = entry_computation->name().empty() @@ -637,9 +636,10 @@ StatusOr> CpuCompiler::RunBackend( : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(entry_computation, function_name_prefix, - /*is_top_level_computation=*/true, - &module_sequence.at(entry_computation))); + ir_emitter.EmitComputation( + entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &schedule.sequence(entry_computation).instructions())); string function_name = [&]() { llvm::SmallVector function_name_vector; @@ -771,20 +771,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, VLOG(2) << "After optimization:"; XLA_VLOG_LINES(2, module->ToString()); - TF_ASSIGN_OR_RETURN( - SequentialHloOrdering::HloModuleSequence module_sequence, - ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); + TF_ASSIGN_OR_RETURN(HloSchedule schedule, + ScheduleModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module, - absl::make_unique(module, module_sequence), - BufferSizeBytesFunction(), memory_alignment, - /*allow_input_output_aliasing=*/false, - /*allocate_buffers_for_constants=*/true)); + BufferAssigner::Run(module, + absl::make_unique(schedule), + BufferSizeBytesFunction(), memory_alignment, + /*allow_input_output_aliasing=*/false, + /*allocate_buffers_for_constants=*/true)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -824,18 +822,18 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, } TF_RETURN_IF_ERROR( ir_emitter - .EmitComputation(embedded_computation, - embedded_computation->name(), - /*is_top_level_computation=*/false, - &module_sequence.at(embedded_computation)) + .EmitComputation( + embedded_computation, embedded_computation->name(), + /*is_top_level_computation=*/false, + &schedule.sequence(embedded_computation).instructions()) .status()); } const string& entry_point_name = options.entry_point_name(); - TF_ASSIGN_OR_RETURN( - llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, entry_point_name, - /*is_top_level_computation=*/true, - &module_sequence.at(computation))); + TF_ASSIGN_OR_RETURN(llvm::Function * entry_function, + ir_emitter.EmitComputation( + computation, entry_point_name, + /*is_top_level_computation=*/true, + &schedule.sequence(computation).instructions())); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index e5cf15c686..df8c2a636b 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -110,7 +110,7 @@ IrEmitter::IrEmitter( StatusOr IrEmitter::EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order) { + const std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]; ordered? " << (instruction_order != nullptr); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 58a333b8fb..3df99464ba 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -98,7 +98,7 @@ class IrEmitter : public DfsHloVisitorWithDefault, StatusOr EmitComputation( HloComputation* computation, const string& function_name_prefix, bool is_top_level_computation, - std::vector* instruction_order); + const std::vector* instruction_order); llvm::IRBuilder<>* b() { return &b_; } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a68b7a1bef..13ccff35f8 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -813,6 +813,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", + "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc index 743035a84e..ea9376e101 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" @@ -198,11 +199,12 @@ StatusOr> GpuHloSchedule::Build( // All kernels are launched on a single stream, so there's no loss of // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( - schedule->thunk_launch_order_, - ScheduleOneComputation( + HloInstructionSequence sequence, + ScheduleComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); + schedule->thunk_launch_order_ = sequence.instructions(); } else { // BFS tends to increase concurrency, but also increases memory usage. BFSLaunchOrder(entry_computation, &schedule->thunk_launch_order_); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h index 30a0e7cecd..07a7fc67aa 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_schedule.h @@ -33,7 +33,9 @@ namespace gpu { // launches, because thunks may be scheduled onto concurrent streams. This // schedule is used by BufferAssigner to determine buffer liveness (i.e. to // minimize allocations), and also by ThunkSchedule to determine the thunk -// launch order. +// launch order. This class differs from xla::HloSchedule in that HloSchedule +// represents a total order of all instructions in the module for backends which +// execute HLO instructions strictly sequentially. class GpuHloSchedule { public: // Constructs an GpuHloSchedule for the given module, based on the given diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 38c3982ebf..e0f3a7e0e2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -29,13 +29,13 @@ using tensorflow::gtl::FlatSet; /*static*/ StatusOr HeapSimulator::MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { + if (schedule.empty()) { return 0; } - const HloModule* module = module_sequence.begin()->first->parent(); + const HloModule* module = schedule.module(); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(module)); @@ -47,14 +47,13 @@ StatusOr HeapSimulator::MinimumMemoryForModule( TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(absl::make_unique(), *module, - module_sequence, *points_to_analysis, size_function)); + schedule, *points_to_analysis, size_function)); return result.heap_size; } /*static*/ StatusOr HeapSimulator::MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* @@ -71,13 +70,13 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { - HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); + HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); const HloComputation* entry_computation = module.entry_computation(); - const std::vector& instruction_sequence = - FindOrDie(module_sequence, entry_computation); + const HloInstructionSequence& instruction_sequence = + schedule.sequence(entry_computation); TF_RETURN_IF_ERROR(heap.RunComputation( *entry_computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -86,13 +85,13 @@ StatusOr HeapSimulator::Run( /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, const tensorflow::gtl::FlatMap* memory_by_computation) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*module_sequence=*/nullptr, memory_by_computation); + /*schedule=*/nullptr, memory_by_computation); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); return heap.Finish(); @@ -102,7 +101,7 @@ StatusOr HeapSimulator::Run( // 'instruction_sequence'. Status HeapSimulator::RunComputation( const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential @@ -133,7 +132,8 @@ Status HeapSimulator::RunComputation( // set of instructions that need to be visited contains all users of all // aliases, that is, all users of all instructions that have the buffer // contained in their points-to set. - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const PointsToSet& points_to = points_to_analysis.GetPointsToSet(instruction); const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); @@ -166,7 +166,8 @@ Status HeapSimulator::RunComputation( std::vector dead_buffers_to_free; std::vector operand_buffers_to_free; - for (const HloInstruction* instruction : instruction_sequence) { + for (const HloInstruction* instruction : + instruction_sequence.instructions()) { const TuplePointsToAnalysis::BufferDefinitionVector& buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); @@ -285,14 +286,14 @@ Status HeapSimulator::RunComputation( // The order that the sub-computations are simulated does not affect // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. - if (module_sequence_ != nullptr) { + if (schedule_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { - const std::vector& called_sequence = - FindOrDie(*module_sequence_, called_computation); + const HloInstructionSequence& called_sequence = + schedule_->sequence(called_computation); TF_RETURN_IF_ERROR(RunComputation( *called_computation, called_sequence, points_to_analysis)); } @@ -343,16 +344,16 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence, + const HloSchedule* schedule, const tensorflow::gtl::FlatMap* memory_by_computation) : no_fragmentation_stats_(absl::make_unique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - module_sequence_(module_sequence), + schedule_(schedule), memory_by_computation_(memory_by_computation) { - debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); + debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } HeapSimulator::~HeapSimulator() {} diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index af05bedee7..ffbf947d5a 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -88,23 +89,22 @@ class HeapSimulator { // Returns the minimum memory required to compute an HLO module where all // computations have been scheduled (represented by the given - // module_sequence), assuming no fragmentation. + // schedule), assuming no fragmentation. static StatusOr MinimumMemoryForModule( - const SequentialHloOrdering::HloModuleSequence& module_sequence, + const HloSchedule& schedule, const LogicalBuffer::SizeFunction& size_function); // Returns the minimum memory required to compute the given computation, // assuming no fragmentation. static StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, + const HloComputation& computation, const HloInstructionSequence& sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap* memory_by_computation = nullptr); // Run the heap simulation with the given algorithm, assuming the given - // module_sequence, which must contain a topologically-consistent total + // schedule, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid // if instructions are not run in exactly this sequence. // @@ -112,12 +112,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr Run( - std::unique_ptr algorithm, const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr Run(std::unique_ptr algorithm, + const HloModule& module, + const HloSchedule& schedule, + const TuplePointsToAnalysis& points_to_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -126,7 +126,7 @@ class HeapSimulator { static StatusOr Run( std::unique_ptr algorithm, const HloComputation& computation, - const std::vector& instruction_sequence, + const HloInstructionSequence& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const BufferValue::SizeFunction& size_fn, const Options& options = Options(), @@ -134,21 +134,19 @@ class HeapSimulator { memory_by_computation = nullptr); private: - // If 'module_sequence' is non-null, it is used to find kCall and kWhile + // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. - HeapSimulator( - std::unique_ptr algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const SequentialHloOrdering::HloModuleSequence* module_sequence = nullptr, - const tensorflow::gtl::FlatMap* - memory_by_computation = nullptr); + HeapSimulator(std::unique_ptr algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, const HloSchedule* schedule = nullptr, + const tensorflow::gtl::FlatMap* + memory_by_computation = nullptr); ~HeapSimulator(); - Status RunComputation( - const HloComputation& computation, - const std::vector& instruction_sequence, - const TuplePointsToAnalysis& points_to_analysis); + Status RunComputation(const HloComputation& computation, + const HloInstructionSequence& instruction_sequence, + const TuplePointsToAnalysis& points_to_analysis); bool IgnoreBuffer(const BufferValue* buffer) const; void Alloc(const BufferValue* buffer, const HloInstruction* instruction); @@ -169,11 +167,11 @@ class HeapSimulator { const std::unique_ptr algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // module_sequence_ is set by buffer assignment, and memory_by_computation_ is + // schedule_ is set by buffer assignment, and memory_by_computation_ is // set by hlo scheduling. Then, in RunComputation, we check both in order to // handle subcomputations. It would be good to unify the handling of // subcomputations, but it's not clear how. - const SequentialHloOrdering::HloModuleSequence* module_sequence_; + const HloSchedule* schedule_; const tensorflow::gtl::FlatMap* memory_by_computation_; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 7ad8a107e1..00a25db467 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { @@ -85,13 +86,16 @@ TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, HeapSimulator::MinimumMemoryForModule(module_sequence, size_fn) - .ValueOrDie()); + HloSchedule schedule(module.get()); + schedule.set_sequence(cond_computation, + {cond_param, cond_iter, cond_data, cond_lt}); + schedule.set_sequence(body_computation, {body_param}); + schedule.set_sequence(entry_computation, {iter, data, tuple, while_op}); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ( + 56, + HeapSimulator::MinimumMemoryForModule(schedule, size_fn).ValueOrDie()); } const char kAlloc[] = "Alloc"; @@ -149,10 +153,11 @@ class HeapSimulatorTracker { auto zero_size = [](const BufferValue& buffer) { return 0; }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run( - std::move(algorithm), *module_->entry_computation(), - instruction_sequence, *points_to_analysis_, zero_size) - .ConsumeValueOrDie(); + result_ = + HeapSimulator::Run(std::move(algorithm), *module_->entry_computation(), + HloInstructionSequence(instruction_sequence), + *points_to_analysis_, zero_size) + .ConsumeValueOrDie(); } explicit HeapSimulatorTracker(const string& name) { @@ -168,11 +173,12 @@ class HeapSimulatorTracker { TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); // Construct the module sequence grouped by computation. - SequentialHloOrdering::HloModuleSequence module_sequence; + HloSchedule schedule(module_.get()); tensorflow::gtl::FlatMap reverse_position; for (int i = 0; i < full_module_sequence.size(); ++i) { const HloInstruction* instruction = full_module_sequence[i]; - module_sequence[instruction->parent()].push_back(instruction); + schedule.GetOrCreateSequence(instruction->parent()) + .push_back(instruction); reverse_position[instruction] = full_module_sequence.size() - i; } @@ -185,8 +191,8 @@ class HeapSimulatorTracker { }; auto algorithm = absl::make_unique( absl::make_unique(&actual_calls_)); - result_ = HeapSimulator::Run(std::move(algorithm), *module_, - module_sequence, *points_to_analysis_, size_fn) + result_ = HeapSimulator::Run(std::move(algorithm), *module_, schedule, + *points_to_analysis_, size_fn) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 54abe3345d..0cd0ab36fc 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -885,18 +885,20 @@ TEST_F(HloAliasAnalysisTest, WhileInterference) { // For a sequential order, if there is interference iff the negate is after // the while. - SequentialHloOrdering::HloModuleSequence sequence; - sequence[body] = {body_param, body_root}; - sequence[condition] = {cond_param, cond_root}; + HloSchedule schedule(module_); + schedule.set_sequence(body, {body_param, body_root}); + schedule.set_sequence(condition, {cond_param, cond_root}); { - sequence[entry] = {init, xla_while, negate, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, xla_while, negate, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); } { - sequence[entry] = {init, negate, xla_while, entry_root}; - SequentialHloOrdering ordering(module_, sequence); + schedule.set_sequence(entry, {init, negate, xla_while, entry_root}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 72b236801a..510d6360a1 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -1261,9 +1262,10 @@ TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param0, negate, param1, exp, add}}); - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param0, negate, param1, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -1339,14 +1341,16 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { bool ssa_form = GetParam(); RunAnalysis(ssa_form); - SequentialHloOrdering::HloModuleSequence sequence; - sequence.insert({entry, {param, xla_while}}); - sequence.insert({condition, {cond_param, cond_constant}}); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, xla_while}); + schedule.set_sequence(condition, {cond_param, cond_constant}); // Construct the order such that 'constant' and its use 'exp' are before // body_param. - sequence.insert({body, {constant, exp, body_param, add}}); + schedule.set_sequence( + body, {constant, exp, body_param, add, dead_constant, dead_negate}); + TF_ASSERT_OK(schedule.Verify()); - SequentialHloOrdering ordering(module_.get(), sequence); + SequentialHloOrdering ordering(schedule); // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. @@ -1476,11 +1480,10 @@ TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { auto entry = module_->AddEntryComputation(builder.Build()); RunAnalysis(GetParam()); - SequentialHloOrdering::HloModuleSequence sequence; - std::vector order = {param, negate, exp, add}; - sequence.emplace(entry, order); - - SequentialHloOrdering ordering(module_.get(), sequence); + HloSchedule schedule(module_.get()); + schedule.set_sequence(entry, {param, negate, exp, add}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 0581d5c404..2105f7a349 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } + + if (a.live_out_of_module()) { + VLOG(4) << a << " is live out of module and defined before " << b; + return false; + } + // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), @@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } } + + if (a.instruction()->parent() == b.instruction()->parent()) { + for (const HloPosition& position : a.positions()) { + if (position.instruction == + a.instruction()->parent()->root_instruction()) { + VLOG(4) << a << " is live out of computation and defined before " << b + << " which is in same computation"; + return false; + } + } + } + return true; } @@ -336,15 +355,24 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : HloOrdering(module), module_sequence_(module_sequence) { +SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) + : HloOrdering(schedule.module()), schedule_(schedule) { + Initialize(); +} + +SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) + : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { + Initialize(); +} + +void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. - for (auto computation_order : module_sequence_) { - const std::vector& order = computation_order.second; + TF_DCHECK_OK(schedule_.Verify()); + for (const auto& computation_sequence : schedule_.sequences()) { + const std::vector& order = + computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); + InsertOrDie(&order_position_, order[i], i); } } } @@ -362,49 +390,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - auto find_it = module_sequence_.find(&computation); - return find_it == module_sequence_.end() ? nullptr : &find_it->second; + return schedule_.is_computation_scheduled(&computation) + ? &schedule_.sequence(&computation).instructions() + : nullptr; } string SequentialHloOrdering::ToString() const { - std::vector pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto* computation : module_->computations()) { - pieces.push_back( - absl::StrFormat("computation %s order:", computation->name())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. - std::vector instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back(absl::StrFormat(" %s", instruction->name())); - } - } - return absl::StrJoin(pieces, "\n"); -} - -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence) { - for (auto computation_pair : module_sequence) { - const HloComputation* computation = computation_pair.first; - const std::vector& computation_sequence = - computation_pair.second; - out << "Computation " << computation->name() << ":\n"; - for (auto* instruction : computation_sequence) { - out << " " << instruction->name() << "\n"; - } - } - return out; + return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index 985f3fa64d..b21071c4b2 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -183,17 +184,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - // TODO(dimvar): HloModuleSequence is not a good name because it sounds like - // a sequence of modules, instead of a map of schedules for all computations - // in a module. We should change it at some point. - // - // A sequence of instructions for each computation in the module. - using HloModuleSequence = - tensorflow::gtl::FlatMap>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); + SequentialHloOrdering(const HloSchedule& schedule); + SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation. @@ -203,10 +195,12 @@ class SequentialHloOrdering : public HloOrdering { string ToString() const override; protected: + void Initialize(); + bool ExecutesBeforeInSameComputation(const HloInstruction* a, const HloInstruction* b) const override; - const HloModuleSequence module_sequence_; + const HloSchedule schedule_; // The position of every instruction in the HLO module in its respective // computation sequence (a value of zero indicates the instruction is first in @@ -217,10 +211,6 @@ class SequentialHloOrdering : public HloOrdering { tensorflow::gtl::FlatMap order_position_; }; -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 126d3a2d9c..6b6005e7a5 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -23,11 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -376,5 +378,104 @@ ENTRY root { dataflow->GetValueDefinedAt(add_3))); } +TEST_F(HloOrderingTest, + ValuesLiveOutOfModuleInterfereWithInstructionsAfterRoot) { + // Tests that values live out of the module should interfere with values + // defined after the root instruction. That is: + // + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* entry = + module->AddEntryComputation(builder.Build(/*root_instruction=*/root)); + + HloSchedule schedule(module.get()); + schedule.set_sequence(entry, {param, root, dead}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + +TEST_F(HloOrderingTest, + ValuesLiveOutOfComputationInterfereWithInstructionsAfterRoot) { + // Tests that values live out of a computation should interfere with values + // defined after the root instruction of the computation. That is: + // + // subcomputation: + // %param = param(0) + // ROOT %root = negate(%param) + // %dead = Constant(123.0) + // + // entry computation: + // %c = constant(42.0) + // ROOT %call = call({%c}), subcomputation + // + // %root should interfere with %dead. + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto subbuilder = HloComputation::Builder(TestName() + ".sub"); + HloInstruction* param = subbuilder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* root = subbuilder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + HloInstruction* dead = subbuilder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloComputation* subcomputation = module->AddEmbeddedComputation( + subbuilder.Build(/*root_instruction=*/root)); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction* call = builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {c}, subcomputation)); + HloComputation* entry = module->AddEntryComputation(builder.Build()); + + HloSchedule schedule(module.get()); + schedule.set_sequence(subcomputation, {param, root, dead}); + schedule.set_sequence(entry, {c, call}); + TF_ASSERT_OK(schedule.Verify()); + SequentialHloOrdering ordering(schedule); + + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); + + EXPECT_TRUE(ordering.ExecutesBefore(root, dead)); + EXPECT_FALSE(ordering.ExecutesBefore(dead, root)); + + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(root), dataflow->GetValueDefinedAt(dead), + *dataflow)); + + EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(root), + dataflow->GetValueDefinedAt(dead), + *dataflow)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c9629926ea..0a0a6a323e 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -962,8 +962,7 @@ StatusOr HloRematerialization::CalledComputationsMemoryUsage( } StatusOr HloRematerialization::RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, + HloComputation* computation, HloSchedule* schedule, int64 memory_limit_bytes) { VLOG(1) << "Rematerializing computation " << computation->name() << " with limit " << HumanReadableNumBytes(memory_limit_bytes); @@ -971,7 +970,8 @@ StatusOr HloRematerialization::RematerializeComputation( << HumanReadableNumBytes(computation_peak_memory_.at(computation)); CHECK(!ContainsKey(rematerialized_computations_, computation)); - InstructionList instruction_list(sequence->at(computation)); + InstructionList instruction_list( + schedule->sequence(computation).instructions()); MemoryUsageTracker memory_tracker(computation, size_function_, *points_to_analysis_, instruction_list); bool changed = false; @@ -1145,7 +1145,7 @@ StatusOr HloRematerialization::RematerializeComputation( 0, memory_limit_bytes - memory_tracker.memory_usage()); TF_ASSIGN_OR_RETURN( bool subcomputation_changed, - RematerializeComputation(called_computation, sequence, + RematerializeComputation(called_computation, schedule, subcomputation_memory_limit_bytes)); changed |= subcomputation_changed; } @@ -1179,12 +1179,12 @@ StatusOr HloRematerialization::RematerializeComputation( computation_peak_memory_.at(computation) = peak_memory; // Update order to include rematerialized instructions. - auto& dst = sequence->at(computation); - dst.clear(); + HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation); + sequence.clear(); for (auto* item = instruction_list.first(); item != nullptr; item = instruction_list.next(item)) { const HloInstruction* instruction = item->instruction; - dst.push_back(instruction); + sequence.push_back(instruction); } rematerialized_computations_.insert(computation); @@ -1194,20 +1194,21 @@ StatusOr HloRematerialization::RematerializeComputation( return changed; } -StatusOr HloRematerialization::Run( - HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence, - int64 memory_limit_bytes, RematerializationSizes* sizes, - CopyInsertion* copy_insertion) { - // The sequence is constructed entirely by this method. - TF_RET_CHECK(sequence->empty()); +StatusOr HloRematerialization::Run(HloModule* module, + HloSchedule* schedule, + int64 memory_limit_bytes, + RematerializationSizes* sizes, + CopyInsertion* copy_insertion) { + // The schedule is constructed entirely by this method. + TF_RET_CHECK(schedule->empty()); VLOG(1) << "HloRematerialization() with memory limit of " << HumanReadableNumBytes(memory_limit_bytes); XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); - // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( - *module, + // Create initial schedule of HLO instructions. + TF_ASSIGN_OR_RETURN(*schedule, + ScheduleModule(*module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); }, @@ -1217,16 +1218,7 @@ StatusOr HloRematerialization::Run( // ordering from the HLO schedule allows for more copies to be eliminated. // TODO(b/80249101): Instead of a separate copy elision pass, use the // ordering from the HLO schedule directly for copy insertion. - - // First create a copy of the schedule which contains HloInstruction unique - // ids instead of HloInstruction*. This is necessary for updating the - // schedule below. - // TODO(b/113175018): Remove this when the HLO schedule is self-contained - // and can update itself. - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(*sequence); - - SequentialHloOrdering ordering(module, *sequence); + SequentialHloOrdering ordering(*schedule); TF_RETURN_IF_ERROR( copy_insertion->RemoveUnnecessaryCopies(ordering, module)); @@ -1241,10 +1233,10 @@ StatusOr HloRematerialization::Run( // The passes above can add and remove copies, update the schedule to // account for these transformations. Newly added instructions will be // placed ASAP in the schedule. - TF_RETURN_IF_ERROR(UpdateSchedule(*module, id_sequence, sequence)); + TF_RETURN_IF_ERROR(schedule->Update()); TF_DCHECK_OK(copy_insertion->VerifyNoLiveRangeInterference( - SequentialHloOrdering(module, *sequence), module)); + SequentialHloOrdering(*schedule), module)); } TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module)); @@ -1271,12 +1263,13 @@ StatusOr HloRematerialization::Run( // sequential context. call_graph_ = CallGraph::Build(module); TF_RETURN_IF_ERROR(call_graph_->VisitNodes( - [this, sequence](const CallGraphNode& node) -> Status { + [this, schedule](const CallGraphNode& node) -> Status { if (node.context() == CallContext::kSequential) { TF_ASSIGN_OR_RETURN( computation_peak_memory_[node.computation()], - ComputePeakMemory(node.computation(), - sequence->at(node.computation()))); + ComputePeakMemory( + node.computation(), + schedule->sequence(node.computation()).instructions())); } return Status::OK(); }, @@ -1295,7 +1288,7 @@ StatusOr HloRematerialization::Run( // Subcomputations called by the entry computation will also be // rematerialized. TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( - module->entry_computation(), sequence, + module->entry_computation(), schedule, adjusted_memory_limit_bytes)); // Rematerialization can introduce dead code. This occurs if all uses of an @@ -1305,30 +1298,7 @@ StatusOr HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. - for (const auto* computation : module->MakeNonfusionComputations()) { - if (sequence->at(computation).size() != computation->instruction_count()) { - // A size mismatch between the computation instruction count and the size - // of the ordering of instructions can only be caused by DCE. Rebuild the - // order by removing the deleted instructions from the order. - tensorflow::gtl::FlatSet instruction_set; - for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction); - } - // Move the old order into a temporary vector, then build new order - // inplace. - std::vector& order = sequence->at(computation); - std::vector old_order; - using std::swap; - swap(order, old_order); - std::copy_if(old_order.begin(), old_order.end(), - std::back_inserter(order), - [&instruction_set](const HloInstruction* instruction) { - return ContainsKey(instruction_set, instruction); - }); - TF_RET_CHECK(sequence->at(computation).size() == - computation->instruction_count()); - } - } + TF_RETURN_IF_ERROR(schedule->Update()); VLOG(1) << "Rematerialized " << instructions_rematerialized_ << " instructions in module " << module->name() << "; " << net_instructions_added_ << " net instructions added"; @@ -1366,11 +1336,10 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, - MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, + MemorySchedulerAlgorithm scheduler_algorithm, HloSchedule* schedule, RematerializationSizes* sizes, CopyInsertion* copy_insertion) { HloRematerialization remat(scheduler_algorithm, size_function); - return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes, + return remat.Run(hlo_module, schedule, memory_limit_bytes, sizes, copy_insertion); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 2ec004350a..fa0414b472 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -21,6 +21,7 @@ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -50,7 +51,7 @@ class HloRematerialization { // // hlo_module: HLO module to rematerialize instructions in. // - // sequence: Should point to an empty HloModuleSequence. Upon return + // schedule: Should point to an empty HloSchedule. Upon return // contains the HLO instruction order which was used for // rematerialization. This is the order in which HLO instructions should // be emitted to minimize memory use. @@ -75,8 +76,8 @@ class HloRematerialization { static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm, - SequentialHloOrdering::HloModuleSequence* sequence, - RematerializationSizes* sizes, CopyInsertion* copy_insertion = nullptr); + HloSchedule* schedule, RematerializationSizes* sizes, + CopyInsertion* copy_insertion = nullptr); protected: HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm, @@ -87,10 +88,9 @@ class HloRematerialization { // Runs rematerialization on the given module. Returns whether the module was // changed. memory_limit is the target maximum peak memory usage by the - // module. sequence should be an empty HloModuleSequence. Upon return sequence + // module. schedule should be an empty HloSchedule. Upon return sequence // contains the memory-minimizing order in which to emit the HLO instructions. - StatusOr Run(HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence, + StatusOr Run(HloModule* module, HloSchedule* schedule, int64 memory_limit, RematerializationSizes* sizes, CopyInsertion* copy_insertion); @@ -98,10 +98,9 @@ class HloRematerialization { // order in which the computation's instructions will be emitted in the // backend. Rematerialized instructions will be added to the HLO computation // and inserted into 'order'. - StatusOr RematerializeComputation( - HloComputation* computation, - SequentialHloOrdering::HloModuleSequence* sequence, - int64 computation_memory_limit); + StatusOr RematerializeComputation(HloComputation* computation, + HloSchedule* schedule, + int64 memory_limit_bytes); // Computes and returns the peak memory used by the given computation. The // peak memory is the maximum total size of all live HLO instruction values at diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index ac8c97d380..83cb113bfb 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -141,13 +141,13 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } - StatusOr RunHloRematerialization( - int64 memory_limit_bytes, HloModule* module, - SequentialHloOrdering::HloModuleSequence* sequence) { + StatusOr RunHloRematerialization(int64 memory_limit_bytes, + HloModule* module, + HloSchedule* schedule) { TF_EXPECT_OK(verifier().Run(module).status()); return HloRematerialization::RematerializeAndSchedule( ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, - sequence, /*sizes=*/nullptr); + schedule, /*sizes=*/nullptr); } // Various shapes used in the canned computations. @@ -170,12 +170,12 @@ TEST_F(HloRematerializationTest, SingleComputation) { const HloInstruction* concat = slice->operand(0); const HloInstruction* bcast = concat->operand(0); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/14 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Root should not have changed. @@ -187,9 +187,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { // The rematerialized broadcast should be immediate before the concat in the // sequence. - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 2], + EXPECT_EQ(schedule.sequence(computation) + .instructions()[computation->instruction_count() - 2], concat); - EXPECT_EQ(sequence.at(computation)[computation->instruction_count() - 3], + EXPECT_EQ(schedule.sequence(computation) + .instructions()[computation->instruction_count() - 3], remat_bcast); } @@ -203,10 +205,10 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/20 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -242,10 +244,10 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/17 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. @@ -276,10 +278,10 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(entry_computation->instruction_count(), 7); EXPECT_EQ(body_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/15 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // Both computations should have rematerialized instructions added. @@ -316,10 +318,10 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/13 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // All computations should have rematerialized instructions added. @@ -382,14 +384,14 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { ASSERT_EQ(count_rngs(entry_computation), 1); const int64 original_instruction_count = entry_computation->instruction_count(); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,13 +478,13 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { EXPECT_EQ(add_3->operand(0), bcast); EXPECT_EQ(add_4->operand(0), bcast); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -571,13 +573,13 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { EXPECT_EQ(entry_computation->instruction_count(), 8); - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(module.get()); // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( /*memory_limit_bytes=*/22 * 1024, - module.get(), &sequence)); + module.get(), &schedule)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc new file mode 100644 index 0000000000..a65b33bf40 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -0,0 +1,291 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { + +void HloSchedule::set_sequence( + const HloComputation* computation, + absl::Span sequence) { + set_sequence(computation, HloInstructionSequence(sequence)); +} + +void HloSchedule::set_sequence(const HloComputation* computation, + HloInstructionSequence sequence) { + CHECK(computation->parent() == module_); + sequences_[computation->unique_id()] = std::move(sequence); +} + +HloInstructionSequence& HloSchedule::GetOrCreateSequence( + const HloComputation* computation) { + auto it = sequences_.find(computation->unique_id()); + if (it == sequences_.end()) { + // No sequence found for computation. Create and return an empty one. + CHECK(computation->parent() == module_); + return sequences_[computation->unique_id()]; + } else { + return it->second; + } +} + +const HloInstructionSequence& HloSchedule::sequence( + const HloComputation* computation) const { + return sequences_.at(computation->unique_id()); +} + +Status HloSchedule::UpdateComputationSchedule( + const HloComputation* computation) { + // Map from unique ID to HloInstruction pointer for instructions in the + // computation. + tensorflow::gtl::FlatMap id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction); + } + + // Set of all HloInstructions in the schedule. + tensorflow::gtl::FlatSet ids_in_schedule; + for (int id : sequences_.at(computation->unique_id()).ids()) { + InsertOrDie(&ids_in_schedule, id); + } + + // Map from HloInstruction X to newly added instructions (instruction is in + // computation, but not in schedule) which use X. If an instruction is not in + // the map, then it has no users which are newly added instructions. + tensorflow::gtl::FlatMap> + new_instruction_uses; + + // For each newly added instruction, this is the count of the instruction's + // operands that have not yet been scheduled. When this value reaches zero, + // then the instruction may be placed in the schedule. + tensorflow::gtl::FlatMap + unscheduled_operand_count; + + // Create a worklist of newly added instructions which are ready to be added + // to the schedule. Initialize worklist with those that have zero operands. + std::queue worklist; + + for (const HloInstruction* instruction : computation->instructions()) { + if (ids_in_schedule.count(instruction->unique_id()) == 0) { + // This is a newly added instruction which is not in the schedule. + if (instruction->operands().empty()) { + worklist.push(instruction); + } else { + for (const HloInstruction* operand : instruction->operands()) { + new_instruction_uses[operand].push_back(instruction); + } + unscheduled_operand_count[instruction] = instruction->operand_count(); + } + } + } + + // Update the schedule with the newly added instructions, and remove any + // instructions no longer in the graph. + HloInstructionSequence new_sequence; + + // Lambda which schedules all instructions on the worklist. + auto schedule_worklist = [&]() { + while (!worklist.empty()) { + const HloInstruction* instruction = worklist.front(); + worklist.pop(); + new_sequence.push_back(instruction); + std::vector* new_users = + tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); + if (new_users != nullptr) { + // This just-scheduled instruction has users which are newly added to + // the module. Update the number of unscheduled operands and push the + // newly added instruction to the worklist if it is ready to + // schedule. + for (const HloInstruction* new_user : *new_users) { + unscheduled_operand_count.at(new_user)--; + CHECK_GE(unscheduled_operand_count.at(new_user), 0); + if (unscheduled_operand_count.at(new_user) == 0) { + worklist.push(new_user); + } + } + } + } + }; + + schedule_worklist(); + for (int id : sequences_.at(computation->unique_id()).ids()) { + auto it = id_to_instruction.find(id); + if (it == id_to_instruction.end()) { + // This instruction in the schedule is no longer in the module. Do not add + // it to the new schedule. + continue; + } + worklist.push(it->second); + schedule_worklist(); + } + + set_sequence(computation, std::move(new_sequence)); + return Status::OK(); +} + +Status HloSchedule::Update() { + // The schedule must contain a sequence for every non-fusion computation in + // the module, but can have sequences for computations which no longer exist + // (these are removed). + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() << " not in HloSchedule."; + } + if (sequences_.size() > nonfusion_computations.size()) { + // Schedule contains some computations which have been removed from the + // HloModule. Remove them from the schedule as well. + tensorflow::gtl::FlatSet nonfusion_computations_ids; + for (const HloComputation* computation : nonfusion_computations) { + nonfusion_computations_ids.insert(computation->unique_id()); + } + for (auto it = sequences_.begin(); it != sequences_.end();) { + if (nonfusion_computations_ids.count(it->first) == 0) { + it = sequences_.erase(it); + } else { + it++; + } + } + } + CHECK_EQ(sequences_.size(), nonfusion_computations.size()); + + for (const HloComputation* computation : nonfusion_computations) { + TF_RETURN_IF_ERROR(UpdateComputationSchedule(computation)); + } + + TF_RETURN_IF_ERROR(Verify()); + return Status::OK(); +} + +Status HloSchedule::Verify() const { + VLOG(2) << "VerifySchedule()"; + XLA_VLOG_LINES(3, module_->ToString()); + XLA_VLOG_LINES(2, ToString()); + + // Verify schedule contains exactly the same set of non-fusion computations as + // module currently does. + std::vector nonfusion_computations = + module_->MakeNonfusionComputations(); + TF_RET_CHECK(nonfusion_computations.size() == sequences_.size()) + << "Schedule has " << sequences_.size() << " sequences, but module has " + << nonfusion_computations.size() << " non-fusion computations"; + for (const HloComputation* computation : nonfusion_computations) { + TF_RET_CHECK(sequences_.count(computation->unique_id()) == 1) + << "Computation " << computation->name() + << " missing from HLO schedule."; + } + + // For each computation verify the set of instructions is the same and that + // each dependency and control edge is honored. + for (const HloComputation* computation : nonfusion_computations) { + tensorflow::gtl::FlatMap instruction_position; + int pos = 0; + for (const HloInstruction* instruction : + sequence(computation).instructions()) { + TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) + << "Instruction " << instruction->name() + << " appears more than once in the schedule"; + pos++; + } + + TF_RET_CHECK(instruction_position.size() == + computation->instruction_count()); + for (const HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(instruction_position.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in schedule"; + } + + for (const HloInstruction* instruction : computation->instructions()) { + for (const HloInstruction* operand : instruction->operands()) { + TF_RET_CHECK(instruction_position.at(operand) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its operand " << operand->name(); + } + + for (const HloInstruction* pred : instruction->control_predecessors()) { + TF_RET_CHECK(instruction_position.at(pred) < + instruction_position.at(instruction)) + << "Instruction " << instruction->name() + << " is not scheduled after its control predecessor " + << pred->name(); + } + } + } + + return Status::OK(); +} + +namespace { + +// Returns the computation in the given module with the given unique ID. Returns +// nullptr if no such computation exists. +const HloComputation* IdToComputation(const HloModule* module, int64 id) { + for (const HloComputation* computation : module->computations()) { + if (computation->unique_id() == id) { + return computation; + } + } + return nullptr; +} + +} // namespace + +string HloSchedule::ToString() const { + std::vector pieces; + + pieces.push_back("HloSchedule"); + for (const auto& id_sequence : sequences_) { + const HloComputation* computation = + IdToComputation(module_, id_sequence.first); + if (computation == nullptr) { + // The computation is not in the module and may have been deleted so it is + // not safe to dereference any HLO pointers. Just use the HLO unique ids + // stored in this object. + pieces.push_back( + absl::StrFormat("computation with id %d (no longer in HLO module):", + id_sequence.first)); + for (int id : id_sequence.second.ids()) { + pieces.push_back(absl::StrCat(" ", id)); + } + } else { + pieces.push_back(absl::StrFormat("computation %s:", computation->name())); + for (const HloInstruction* instruction : + id_sequence.second.instructions()) { + pieces.push_back(absl::StrCat(" ", instruction->name())); + } + } + } + return absl::StrJoin(pieces, "\n"); +} + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule) { + out << schedule.ToString(); + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h new file mode 100644 index 0000000000..21c6988638 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/status.h" + +namespace xla { + +// Class representing a sequence of HLO instructions such as the sequential +// execution order of an HLO computation. +class HloInstructionSequence { + public: + HloInstructionSequence() = default; + HloInstructionSequence(absl::Span instructions) { + for (const HloInstruction* instruction : instructions) { + push_back(instruction); + } + } + + // Adds the instruction to the end of the sequence. + void push_back(const HloInstruction* instruction) { + instruction_sequence_.push_back(instruction); + id_sequence_.push_back(instruction->unique_id()); + } + + // Clears the sequence of all instructions. + void clear() { + instruction_sequence_.clear(); + id_sequence_.clear(); + } + + int64 size() const { return instruction_sequence_.size(); } + + // Returns the sequence of HLO instructions. + const std::vector& instructions() const { + return instruction_sequence_; + } + + // Returns the unique IDs of the instructions in the sequence (in order). + const std::vector& ids() const { return id_sequence_; } + + private: + // The sequence as HloInstructions. + std::vector instruction_sequence_; + + // The sequence of HLO instructions, represented by their unique IDs. The + // sequence is stored as both HloInstructions and unique IDs because the + // sequence may be referenced after transformations to the HLO graph and HLO + // pointers can be invalidated or recycled in this process (see + // HloSchedule::Update). + std::vector id_sequence_; +}; + +// A class representing a sequential schedule of instructions for an HLO +// module. A complete HLO schedule contains an instruction sequence for every +// non-fusion computation in the HLO module. +class HloSchedule { + public: + HloSchedule(const HloModule* module) : module_(module) {} + + // Returns a reference to the sequence for the given computation. + const HloInstructionSequence& sequence( + const HloComputation* computation) const; + + // Returns the sequence for the given computation. An empty sequence is + // created if none exists for the computation. + HloInstructionSequence& GetOrCreateSequence( + const HloComputation* computation); + + // Sets the sequence for the given computation to the given sequence. + void set_sequence(const HloComputation* computation, + absl::Span sequence); + void set_sequence(const HloComputation* computation, + HloInstructionSequence sequence); + + // Returns a map from HloComputation unique ID to instruction sequence. The + // map contains all sequences in the schedule. + const tensorflow::gtl::FlatMap& sequences() + const { + return sequences_; + } + + // Returns true if the schedule has a sequence for the given computation. + bool is_computation_scheduled(const HloComputation* computation) const { + return sequences_.count(computation->unique_id()) == 1; + } + + // Updates the schedule such that it is (again) a valid schedule for the + // module. This is used to update a schedule after the HLO module has been + // transformed in some way. In general, the only transformations to the module + // for which a schedule can be updated is the addition or removal of + // instructions and removal of computations. Updating the schedule after new + // dependencies between existing instructions in the module is not supported + // and may result in an error status returned. + // + // Instructions in the module which also exist in the given schedule will + // remain in the same order in the updated schedule. Instructions which exist + // in the module but not in the given schedule will be placed as early as + // possible in the updated schedule. + Status Update(); + + // Verifies that the given schedule is valid for the given module. + // Specifically, the schedule contains exactly the instructions in the + // non-fusion computations in the module and every dependency in the module is + // satisfied in the schedule. + Status Verify() const; + + string ToString() const; + + bool empty() const { return sequences_.empty(); } + + const HloModule* module() const { return module_; } + + private: + // Updates the instruction sequence for the given computation. + Status UpdateComputationSchedule(const HloComputation* computation); + + const HloModule* module_; + + // A map from computation unique ID to instruction sequence. Unique IDs are + // used rather than HloComputation pointers because HLO pointers are not + // unique across HLO transformations because pointers may be recycled. + tensorflow::gtl::FlatMap sequences_; +}; + +std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULE_H_ diff --git a/tensorflow/compiler/xla/service/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/hlo_schedule_test.cc new file mode 100644 index 0000000000..eb52582bb5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_schedule_test.cc @@ -0,0 +1,341 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_schedule.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloScheduleTest : public HloTestBase {}; + +TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) { + // Updating the schedule of an unchanged HLO module should not affect the + // schedule at all. + const string module_str = R"( +HloModule UpdateScheduleUnchanged + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + const std::vector& entry_schedule = + schedule.sequence(module->entry_computation()).instructions(); + + EXPECT_EQ(entry_schedule.size(), 6); + + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(entry_schedule, + schedule.sequence(module->entry_computation()).instructions()); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) { + // Add some additional instructions to a module and verify the schedule can be + // updated. + const string module_str = R"( +HloModule UpdateScheduleWithNewInstructions + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + HloComputation* entry = module->entry_computation(); + const Shape shape = entry->root_instruction()->shape(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, constant, entry->root_instruction())); + entry->set_root_instruction(sub); + + auto in_schedule = [&](const HloInstruction* hlo) { + return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo); + }; + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + EXPECT_FALSE(in_schedule(constant)); + EXPECT_FALSE(in_schedule(sub)); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 8); + EXPECT_TRUE(in_schedule(constant)); + EXPECT_TRUE(in_schedule(sub)); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) { + // Add and delete some instructions from a module and verify that the schedule + // can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithAddedAndDeletedInstruction + +ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + c = f32[] constant(42.0) + sum = f32[] add(a, b) + neg = f32[] negate(c) + ROOT root = f32[] multiply(sum, neg) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Set the entry root to some expression containing just a parameter and a + // constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction* new_root = entry->AddInstruction( + HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, + constant, entry->parameter_instruction(0))); + entry->set_root_instruction(new_root); + + // DCE should remove everything but the parameters and the newly added code. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 6); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 4); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) { + // Completely replace a module with an entirely new set of instructions and + // verify that the schedule can be updated successfully. + const string module_str = R"( +HloModule UpdateScheduleWithCompletelyReplacedModule + +ENTRY main { + a = f32[] constant(42.0) + b = f32[] constant(123.0) + ROOT sum = f32[] add(a, b) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + + // Replace the entry computation with the negation of a constant. + HloComputation* entry = module->entry_computation(); + HloInstruction* constant = entry->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNegate, constant)); + entry->set_root_instruction(new_root); + + // DCE the old instructions. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(entry).size(), 3); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(entry).size(), 2); +} + +TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) { + // Create changes to more than one computation in an HLO module and verify + // that the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + const HloInstruction* xla_while = + module->entry_computation()->root_instruction()->operand(0); + HloComputation* body = xla_while->while_body(); + HloComputation* cond = xla_while->while_condition(); + + // Negate the root of the cond. + cond->set_root_instruction(cond->AddInstruction( + HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kNot, cond->root_instruction()))); + + // Replace the body with a computation which just passes through its + // parameter. + body->set_root_instruction(body->parameter_instruction(0)); + + // DCE the dead code in the body. + HloDCE dce; + TF_ASSERT_OK(dce.Run(module.get()).status()); + + EXPECT_EQ(schedule.sequence(body).size(), 7); + EXPECT_EQ(schedule.sequence(cond).size(), 4); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); + + EXPECT_EQ(schedule.sequence(body).size(), 1); + EXPECT_EQ(schedule.sequence(cond).size(), 5); +} + +TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) { + // Remove computations from a module and verify the schedule can be updated. + const string module_str = R"( +HloModule UpdateScheduleWithMultipleComputations + +%Body (param.1: (s32[], token[])) -> (s32[], token[]) { + %param.1 = (s32[], token[]) parameter(0) + %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 + %constant.1 = s32[] constant(1) + %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) + %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 + %after-all = token[] after-all(token[] %get-tuple-element.2) + ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) +} + +%Cond (param: (s32[], token[])) -> pred[] { + %param = (s32[], token[]) parameter(0) + %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 + %constant = s32[] constant(42) + ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) +} + +ENTRY %WhileLoop () -> s32[] { + %zero = s32[] constant(0) + %init_token = token[] after-all() + %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) + %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body + ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(module_str)); + TF_ASSERT_OK_AND_ASSIGN( + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), + /*pointer_size=*/sizeof(void*)); + })); + + HloInstruction* xla_while = + module->entry_computation()->root_instruction()->mutable_operand(0); + HloInstruction* init = xla_while->mutable_operand(0); + + // Replace the while with its init value. The conditional and body + // computations should then be dead. + TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init)); + + // DCE the dead code in the body. + HloDCE dce; + ASSERT_EQ(module->computation_count(), 3); + TF_ASSERT_OK(dce.Run(module.get()).status()); + ASSERT_EQ(module->computation_count(), 1); + + ASSERT_IS_NOT_OK(schedule.Verify()); + TF_ASSERT_OK(schedule.Update()); + TF_ASSERT_OK(schedule.Verify()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 0fc3b268c0..9bfb0af96c 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -70,7 +70,7 @@ class ListScheduler { public: // Construct and return a memory-minimizing sequence of HLO instructions // containing the given HLO computation. - static StatusOr> Run( + static StatusOr Run( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -229,8 +229,8 @@ class ListScheduler { return {BytesFreedIfScheduled(entry), entry.instruction->user_count()}; } - std::vector CreateSchedule() { - std::vector schedule; + HloInstructionSequence CreateSchedule() { + HloInstructionSequence schedule; // Populate the ready list with instructions which have no operands or // control predecessors. @@ -374,7 +374,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> ScheduleComputationHelper( +StatusOr ScheduleComputationHelper( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -392,7 +392,7 @@ StatusOr> ScheduleComputationHelper( } // namespace -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -443,7 +443,7 @@ StatusOr> DFSMemoryScheduler( // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a // tiebreaker by name for determinism. - std::vector sequence; + HloInstructionSequence sequence; FunctionVisitor visitor([&sequence](HloInstruction* hlo) { sequence.push_back(hlo); return Status::OK(); @@ -463,7 +463,7 @@ StatusOr> DFSMemoryScheduler( return sequence; } // namespace xla -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -473,18 +473,16 @@ StatusOr> ListMemoryScheduler( memory_by_computation); } -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation) { - const auto& post_order = computation.MakeInstructionPostOrder(); - return std::vector{post_order.begin(), - post_order.end()}; + return HloInstructionSequence(computation.MakeInstructionPostOrder()); } -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -499,7 +497,7 @@ StatusOr> DefaultMemoryScheduler( // List wins for most of our benchmarks; postorder-based schedulers win for // some RNNs. TF_ASSIGN_OR_RETURN( - std::vector list_sequence, + HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 list_memory, @@ -508,7 +506,7 @@ StatusOr> DefaultMemoryScheduler( size_function, &memory_by_computation)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); - TF_ASSIGN_OR_RETURN(std::vector dfs_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 dfs_memory, @@ -518,7 +516,7 @@ StatusOr> DefaultMemoryScheduler( VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); TF_ASSIGN_OR_RETURN( - std::vector post_order_sequence, + HloInstructionSequence post_order_sequence, PostOrderMemoryScheduler(computation, points_to_analysis, size_function, memory_by_computation)); TF_ASSIGN_OR_RETURN(const int64 post_order_memory, @@ -545,32 +543,35 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm) { - SequentialHloOrdering::HloModuleSequence sequence; + HloSchedule schedule(&module); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); tensorflow::gtl::FlatMap memory_by_computation; for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN(auto one_computation_sequence, + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, ScheduleComputationHelper( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = HeapSimulator::MinimumMemoryForComputation( - *computation, one_computation_sequence, *points_to_analysis, + *computation, computation_sequence, *points_to_analysis, size_function, &memory_by_computation) .ValueOrDie(); - sequence[computation] = std::move(one_computation_sequence); + schedule.set_sequence(computation, std::move(computation_sequence)); } } - VLOG(1) << "Module schedule:\n" << sequence; - return sequence; + VLOG(1) << "Module schedule:\n" << schedule; + + TF_RETURN_IF_ERROR(schedule.Verify()); + + return std::move(schedule); } -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); @@ -581,187 +582,4 @@ StatusOr> ScheduleOneComputation( size_function, nullptr, empty_map); } -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence) { - tensorflow::gtl::FlatMap> id_sequence; - for (const auto& computation_sequence : sequence) { - for (const HloInstruction* instruction : computation_sequence.second) { - id_sequence[computation_sequence.first].push_back( - instruction->unique_id()); - } - } - return id_sequence; -} - -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence) { - // Map from unique ID to HloInstruction pointer for instructions in the - // module. - tensorflow::gtl::FlatMap id_to_instruction; - // Set of all HloInstructions in the schedule. - tensorflow::gtl::FlatSet ids_in_schedule; - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - for (const HloComputation* computation : nonfusion_computations) { - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK( - id_to_instruction.insert({instruction->unique_id(), instruction}) - .second); - } - for (int id : id_sequence.at(computation)) { - ids_in_schedule.insert(id); - } - } - - // Map from HloInstruction X to newly added instructions (instruction is in - // module, but not in schedule) which use X. If an instruction is not in the - // map, then it has no users which are newly added instructions. - tensorflow::gtl::FlatMap> - new_instruction_uses; - - // For each newly added instruction, this is the count of the instruction's - // operands that have not yet been scheduled. When this value reaches zero, - // then the instruction may be placed in the schedule. - tensorflow::gtl::FlatMap - unscheduled_operand_count; - // For each computation, this is the set of newly added instructions which - // have no operands. These must be handled specially and are added to the - // beginning of the schedule. - tensorflow::gtl::FlatMap> - new_zero_operand_instructions; - for (const HloComputation* computation : nonfusion_computations) { - new_zero_operand_instructions[computation] = {}; - for (const HloInstruction* instruction : computation->instructions()) { - if (ids_in_schedule.count(instruction->unique_id()) == 0) { - // This is a newly added instruction which is not in the schedule. - for (const HloInstruction* operand : instruction->operands()) { - new_instruction_uses[operand].push_back(instruction); - } - if (instruction->operands().empty()) { - new_zero_operand_instructions[computation].push_back(instruction); - } - unscheduled_operand_count[instruction] = instruction->operand_count(); - } - } - } - - // Update the schedule with the newly added instructions, and remove any - // instructions no longer in the graph. - for (const HloComputation* computation : nonfusion_computations) { - std::vector old_computation_sequence = - std::move(sequence->at(computation)); - sequence->at(computation).clear(); - - // Create a worklist of newly added instructions which are ready to be added - // to the schedule. Initialize worklist with those that have zero operands. - std::queue worklist; - for (const HloInstruction* instruction : - new_zero_operand_instructions.at(computation)) { - worklist.push(instruction); - } - - // Lambda which schedules all instructions on the worklist. - auto schedule_worklist = [&]() { - while (!worklist.empty()) { - const HloInstruction* instruction = worklist.front(); - worklist.pop(); - sequence->at(computation).push_back(instruction); - std::vector* new_users = - tensorflow::gtl::FindOrNull(new_instruction_uses, instruction); - if (new_users != nullptr) { - // This just-scheduled instruction has users which are newly added to - // the module. Update the number of unscheduled operands and push the - // newly added instruction to the worklist if it is ready to - // schedule. - for (const HloInstruction* new_user : *new_users) { - unscheduled_operand_count.at(new_user)--; - CHECK_GE(unscheduled_operand_count.at(new_user), 0); - if (unscheduled_operand_count.at(new_user) == 0) { - worklist.push(new_user); - } - } - } - } - }; - - schedule_worklist(); - for (int id : id_sequence.at(computation)) { - auto it = id_to_instruction.find(id); - if (it == id_to_instruction.end()) { - // This instruction in the schedule is no longer in the module. - continue; - } - const HloInstruction* instruction = it->second; - worklist.push(instruction); - schedule_worklist(); - } - } - - TF_RETURN_IF_ERROR(VerifySchedule(module, *sequence)); - return Status::OK(); -} - -Status VerifySchedule( - const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence) { - VLOG(2) << "VerifySchedule()"; - XLA_VLOG_LINES(2, module.ToString()); - VLOG(2) << sequence; - - // Verify the set of computations in the sequence is exactly the set of - // computations in the module. - std::vector nonfusion_computations = - module.MakeNonfusionComputations(); - TF_RET_CHECK(nonfusion_computations.size() == sequence.size()); - tensorflow::gtl::FlatSet computations_in_module( - module.computations().begin(), module.computations().end()); - for (const auto& computation_sequence : sequence) { - TF_RET_CHECK(computations_in_module.count(computation_sequence.first) == 1); - } - - // For each computation verify the set of instructions is the same and that - // each dependency and control edge is honored. - for (const HloComputation* computation : nonfusion_computations) { - tensorflow::gtl::FlatMap instruction_position; - int pos = 0; - for (const HloInstruction* instruction : sequence.at(computation)) { - TF_RET_CHECK(instruction_position.insert({instruction, pos}).second) - << "Instruction " << instruction->name() - << " appears more than once in the schedule"; - pos++; - } - - TF_RET_CHECK(instruction_position.size() == - computation->instruction_count()); - for (const HloInstruction* instruction : computation->instructions()) { - TF_RET_CHECK(instruction_position.count(instruction) == 1) - << "Instruction " << instruction->name() << " is not in schedule"; - } - - for (const HloInstruction* instruction : computation->instructions()) { - for (const HloInstruction* operand : instruction->operands()) { - TF_RET_CHECK(instruction_position.at(operand) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its operand " << operand->name(); - } - - for (const HloInstruction* pred : instruction->control_predecessors()) { - TF_RET_CHECK(instruction_position.at(pred) < - instruction_position.at(instruction)) - << "Instruction " << instruction->name() - << " is not scheduled after its control predecessor " - << pred->name(); - } - } - } - - return Status::OK(); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index d06b8d9a5c..54e32340ba 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" @@ -32,14 +33,14 @@ namespace xla { // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function // that maps a tensor's logical size to its padded size. -typedef std::function>( +typedef std::function( const HloComputation&, const TuplePointsToAnalysis&, const LogicalBuffer::SizeFunction&, const tensorflow::gtl::FlatMap&)> MemorySchedulerAlgorithm; // List scheduler -StatusOr> ListMemoryScheduler( +StatusOr ListMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -47,7 +48,7 @@ StatusOr> ListMemoryScheduler( memory_by_computation); // DFS-order scheduler -StatusOr> DFSMemoryScheduler( +StatusOr DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -55,7 +56,7 @@ StatusOr> DFSMemoryScheduler( memory_by_computation); // Naive Post Order scheduler -StatusOr> PostOrderMemoryScheduler( +StatusOr PostOrderMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -65,63 +66,26 @@ StatusOr> PostOrderMemoryScheduler( // The default scheduling algorithm. Runs both the list scheduler // and the DFS scheduler, and chooses whichever returns a lower min-memory, // not accounting for fragmentation. -StatusOr> DefaultMemoryScheduler( +StatusOr DefaultMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, const tensorflow::gtl::FlatMap& memory_by_computation); -// Returns an HloModuleSequence which seeks to minimize the memory required for +// Returns an HloSchedule which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr ScheduleComputationsInModule( +StatusOr ScheduleModule( const HloModule& module, const LogicalBuffer::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm = {}); // Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> ScheduleOneComputation( +StatusOr ScheduleComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); -// Transforms the given schedule such that it is (again) a valid schedule for -// the module. This is used to update a schedule after the HLO module has been -// transformed in some way. In general, the only transformations to the module -// for which a schedule can be updated is the addition or removal of -// instructions to/from the module. Updating the schedule after new dependencies -// between existing instructions in the module is not supported and may result -// in an error status returned. -// -// Instructions in the module which also exist in the given schedule will remain -// in the same order in the updated schedule. Instructions which exist in the -// module but not in the given schedule will be placed as early as possible in -// the updated schedule. -// -// 'id_sequence' is a mirror of the given schedule 'sequence' but with -// HloInstruction ids rather than HloInstruction pointers. This should be -// constructed using ComputeIdSchedule below after the schedule is constructed -// but before the HLO module is transformed. -Status UpdateSchedule( - const HloModule& module, - const tensorflow::gtl::FlatMap>& - id_sequence, - SequentialHloOrdering::HloModuleSequence* sequence); - -// Constructs a copy of the given schedule but with HloInstruction unique ids -// rather than HloInstruction pointers. This is necessary for updating a -// schedule as HloInstruction points in the schedule may become invalid if -// instructions are removed from the module. Used by UpdateSchedule above.. -// TODO(b/113175018): Remove this function when HLO schedule is its own class. -tensorflow::gtl::FlatMap> -ComputeIdSchedule(const SequentialHloOrdering::HloModuleSequence& sequence); - -// Verifies that the given schedule is valid for the given module. Specifically, -// the schedule contains exactly the instructions in the module and every -// dependency in the module is satisfied in the schedule. -Status VerifySchedule(const HloModule& module, - const SequentialHloOrdering::HloModuleSequence& sequence); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index d49d09d459..6afe51997e 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" @@ -67,19 +68,20 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { + HloSchedule schedule, + ScheduleModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); // The first instruction should be the parameter and the last the root "sub". - EXPECT_EQ(param, sequence.at(module->entry_computation()).front()); - EXPECT_EQ(sub, sequence.at(module->entry_computation()).back()); + EXPECT_EQ(param, sequence.front()); + EXPECT_EQ(sub, sequence.back()); - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); } @@ -108,28 +110,26 @@ ENTRY root { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. - EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); + const std::vector& sequence = + schedule.sequence(module->entry_computation()).instructions(); + EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size()); std::unordered_map instructions_by_name; - for (const HloInstruction* instruction : - sequence.at(module->entry_computation())) { + for (const HloInstruction* instruction : sequence) { instructions_by_name[instruction->name()] = instruction; } // The first instruction should be the parameter and the last the root. - EXPECT_EQ(instructions_by_name.at("param"), - sequence.at(module->entry_computation()).front()); - EXPECT_EQ(instructions_by_name.at("result"), - sequence.at(module->entry_computation()).back()); + EXPECT_EQ(instructions_by_name.at("param"), sequence.front()); + EXPECT_EQ(instructions_by_name.at("result"), sequence.back()); // Instructions "d" and "e" will both be schedulable at the same time, but // instruction "d" allows us to free the buffer of "p1", so the list scheduler // should prefer it. - SequentialHloOrdering ordering(module.get(), sequence); + SequentialHloOrdering ordering(schedule); EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"), instructions_by_name.at("e"))); } @@ -220,13 +220,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(entry_computation).size()); + SequentialHloOrdering ordering(schedule); // This schedule is an example of List's greedy heuristics being suboptimal. // The while_loop is more expensive than transpose, so it would have been // better to schedule it first, instead of during the busy time. @@ -243,13 +243,13 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. The output buffer is aliased, // so we don't double count. EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } @@ -281,19 +281,18 @@ TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf( - buffer.shape(), TUPLE_SIZE); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // tuple allocates the tuple buffer and doesn't free anything. // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. // abs_abs2 should be scheduled before tuple by List. @@ -332,18 +331,18 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { auto fusion = computation->CreateFusionInstruction( {tuple, mul, add}, HloInstruction::FusionKind::kLoop); - TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule( - *module, - [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), 2); - }, - ListMemoryScheduler)); + TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule, + ScheduleModule(*module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), 2); + }, + ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), - sequence.at(module->entry_computation()).size()); - SequentialHloOrdering ordering(module.get(), sequence); + schedule.sequence(module->entry_computation()).size()); + SequentialHloOrdering ordering(schedule); // fusion allocates memory for the tuple elements and doesn't free anything, // so it's more expensive than exp. EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); @@ -391,12 +390,12 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); + HloSchedule schedule, + ScheduleModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. auto entry_computation = module->entry_computation(); - EXPECT_EQ(entry_computation->instruction_count(), - sequence.at(entry_computation).size()); + EXPECT_EQ(module->entry_computation()->instruction_count(), + schedule.sequence(module->entry_computation()).size()); tensorflow::gtl::FlatMap memory_by_computation; memory_by_computation[cond_computation] = 17; @@ -406,262 +405,16 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { // HeapSimulator doesn't account for subcomputations EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); // HeapSimulator accounts for subcomputations. Cond is the largest one. // The output buffer of the while is aliased. EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( - *entry_computation, sequence.at(entry_computation), + *entry_computation, schedule.sequence(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); } -TEST_F(HloSchedulingTest, UpdateScheduleUnchangedModule) { - // Updating the schedule of an unchanged HLO module should not affect the - // schedule at all. - const string module_str = R"( -HloModule UpdateScheduleUnchanged - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - std::vector entry_schedule = sequence.begin()->second; - - EXPECT_EQ(entry_schedule.size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(entry_schedule, sequence.begin()->second); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithNewInstructions) { - // Add some additional instructions to a module and verify the schedule can be - // updated. - const string module_str = R"( -HloModule UpdateScheduleWithNewInstructions - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - HloComputation* entry = module->entry_computation(); - const Shape shape = entry->root_instruction()->shape(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kSubtract, constant, entry->root_instruction())); - entry->set_root_instruction(sub); - - auto in_schedule = [&](const HloInstruction* hlo) { - return std::find(sequence.at(entry).begin(), sequence.at(entry).end(), - hlo) != sequence.at(entry).end(); - }; - - EXPECT_EQ(sequence.at(entry).size(), 6); - EXPECT_FALSE(in_schedule(constant)); - EXPECT_FALSE(in_schedule(sub)); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 8); - EXPECT_TRUE(in_schedule(constant)); - EXPECT_TRUE(in_schedule(sub)); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithAddedAndDeletedInstruction) { - // Add and delete some instructions from a module and verify that the schedule - // can be updated successfully. - const string module_str = R"( -HloModule UpdateScheduleWithAddedAndDeletedInstruction - -ENTRY main { - a = f32[] parameter(0) - b = f32[] parameter(1) - c = f32[] constant(42.0) - sum = f32[] add(a, b) - neg = f32[] negate(c) - ROOT root = f32[] multiply(sum, neg) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Set the entry root to some expression containing just a parameter and a - // constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); - HloInstruction* new_root = entry->AddInstruction( - HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract, - constant, entry->parameter_instruction(0))); - entry->set_root_instruction(new_root); - - // DCE should remove everything but the parameters and the newly added code. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 6); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 4); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithCompletelyReplacedModule) { - // Completely replace a module with an entirely new set of instructions and - // verify that the schedule can be updated successfully. - const string module_str = R"( -HloModule UpdateScheduleWithCompletelyReplacedModule - -ENTRY main { - a = f32[] constant(42.0) - b = f32[] constant(123.0) - ROOT sum = f32[] add(a, b) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape()); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - // Replace the entry computation with the negation of a constant. - HloComputation* entry = module->entry_computation(); - HloInstruction* constant = entry->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kNegate, constant)); - entry->set_root_instruction(new_root); - - // DCE the old instructions. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(entry).size(), 3); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(entry).size(), 2); -} - -TEST_F(HloSchedulingTest, UpdateScheduleWithMultipleComputations) { - // Create changes to more than one computation in an HLO module and verify - // that the schedule can be updated. - const string module_str = R"( -HloModule UpdateScheduleWithMultipleComputations - -%Body (param.1: (s32[], token[])) -> (s32[], token[]) { - %param.1 = (s32[], token[]) parameter(0) - %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0 - %constant.1 = s32[] constant(1) - %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1) - %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1 - %after-all = token[] after-all(token[] %get-tuple-element.2) - ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all) -} - -%Cond (param: (s32[], token[])) -> pred[] { - %param = (s32[], token[]) parameter(0) - %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0 - %constant = s32[] constant(42) - ROOT %less-than = pred[] less-than(s32[] %get-tuple-element, s32[] %constant) -} - -ENTRY %WhileLoop () -> s32[] { - %zero = s32[] constant(0) - %init_token = token[] after-all() - %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token) - %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body - ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0 -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseHloString(module_str)); - TF_ASSERT_OK_AND_ASSIGN( - SequentialHloOrdering::HloModuleSequence sequence, - ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), - /*pointer_size=*/sizeof(void*)); - })); - tensorflow::gtl::FlatMap> - id_sequence = ComputeIdSchedule(sequence); - - const HloInstruction* xla_while = - module->entry_computation()->root_instruction()->operand(0); - HloComputation* body = xla_while->while_body(); - HloComputation* cond = xla_while->while_condition(); - - // Negate the root of the cond. - cond->set_root_instruction(cond->AddInstruction( - HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kNot, cond->root_instruction()))); - - // Replace the body with a computation which just passes through its - // parameter. - body->set_root_instruction(body->parameter_instruction(0)); - - // DCE the dead code in the body. - HloDCE dce; - TF_ASSERT_OK(dce.Run(module.get()).status()); - - EXPECT_EQ(sequence.at(body).size(), 7); - EXPECT_EQ(sequence.at(cond).size(), 4); - - TF_ASSERT_OK(UpdateSchedule(*module, id_sequence, &sequence)); - TF_ASSERT_OK(VerifySchedule(*module, sequence)); - - EXPECT_EQ(sequence.at(body).size(), 1); - EXPECT_EQ(sequence.at(cond).size(), 5); -} - } // namespace } // namespace xla -- GitLab From 857b55492b311cf4161e8528f7e7e9227fc912af Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 5 Sep 2018 17:23:17 -0700 Subject: [PATCH 167/540] Add cuboid convolution benchmarks. PiperOrigin-RevId: 211727610 --- tensorflow/core/kernels/eigen_benchmark_cpu_test.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc index 7c2bbb8148..3b34f650b6 100644 --- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc +++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc @@ -403,9 +403,15 @@ BM_CuboidConvolutions(8, // batch size 16, 5, 5, 5, // filter: count, height, width, panes "conv3d_depth4"); BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); +BM_CuboidConvolutions(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1"); +BM_CuboidConvolutions(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2"); BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); +BM_CuboidConvolutionsBwdInput(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1"); +BM_CuboidConvolutionsBwdInput(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2"); BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); +BM_CuboidConvolutionsBwdKernel(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1"); +BM_CuboidConvolutionsBwdKernel(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2"); -- GitLab From 680e1754b49362858cda8fd6cea52e1cc4c41e6b Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 5 Sep 2018 17:25:13 -0700 Subject: [PATCH 168/540] Deprecate `tf.train.input_producer()` and related APIs. These APIs are based on queue runners, which have been deprecated and will be removed in TensorFlow 2.0. They have been replaced with `tf.data.Dataset`, which provides a more efficient version of the same functionality. PiperOrigin-RevId: 211727844 --- tensorflow/python/training/input.py | 32 ++++++++++++++++--- .../api/golden/v2/tensorflow.train.pbtxt | 20 ------------ 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 94c6b47027..9d9db70890 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -76,7 +76,10 @@ def match_filenames_once(pattern, name=None): collections=[ops.GraphKeys.LOCAL_VARIABLES]) -@tf_export("train.limit_epochs") +@tf_export(v1=["train.limit_epochs"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.") def limit_epochs(tensor, num_epochs=None, name=None): """Returns tensor `num_epochs` times and then raises an `OutOfRange` error. @@ -109,7 +112,12 @@ def limit_epochs(tensor, num_epochs=None, name=None): return array_ops.identity(tensor, name=name) -@tf_export("train.input_producer") +@tf_export(v1=["train.input_producer"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.from_tensor_slices(input_tensor).shuffle" + "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If " + "`shuffle=False`, omit the `.shuffle(...)`.") def input_producer(input_tensor, element_shape=None, num_epochs=None, @@ -192,7 +200,12 @@ def input_producer(input_tensor, return q -@tf_export("train.string_input_producer") +@tf_export(v1=["train.string_input_producer"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.from_tensor_slices(string_tensor).shuffle" + "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If " + "`shuffle=False`, omit the `.shuffle(...)`.") def string_input_producer(string_tensor, num_epochs=None, shuffle=True, @@ -262,7 +275,11 @@ def string_input_producer(string_tensor, cancel_op=cancel_op) -@tf_export("train.range_input_producer") +@tf_export(v1=["train.range_input_producer"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.range(limit).shuffle(limit).repeat(num_epochs)`. If " + "`shuffle=False`, omit the `.shuffle(...)`.") def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None): """Produces the integers from 0 to limit-1 in a queue. @@ -300,7 +317,12 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None, shared_name, "fraction_of_%d_full" % capacity, name) -@tf_export("train.slice_input_producer") +@tf_export(v1=["train.slice_input_producer"]) +@deprecation.deprecated( + None, "Queue-based input pipelines have been replaced by `tf.data`. Use " + "`tf.data.Dataset.from_tensor_slices(tuple(tensor_list)).shuffle" + "(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If " + "`shuffle=False`, omit the `.shuffle(...)`.") def slice_input_producer(tensor_list, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None): """Produces a slice of each `Tensor` in `tensor_list`. diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt index e2b74e4d67..b21dabbde7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt @@ -308,10 +308,6 @@ tf_module { name: "init_from_checkpoint" argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "input_producer" - argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], " - } member_method { name: "inverse_time_decay" argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'decay_rate\', \'staircase\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], " @@ -320,10 +316,6 @@ tf_module { name: "latest_checkpoint" argspec: "args=[\'checkpoint_dir\', \'latest_filename\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "limit_epochs" - argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } member_method { name: "linear_cosine_decay" argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], " @@ -360,10 +352,6 @@ tf_module { name: "polynomial_decay" argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'end_learning_rate\', \'power\', \'cycle\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0001\', \'1.0\', \'False\', \'None\'], " } - member_method { - name: "range_input_producer" - argspec: "args=[\'limit\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " - } member_method { name: "remove_checkpoint" argspec: "args=[\'checkpoint_prefix\', \'checkpoint_format_version\', \'meta_graph_suffix\'], varargs=None, keywords=None, defaults=[\'2\', \'meta\'], " @@ -384,14 +372,6 @@ tf_module { name: "sdca_shrink_l1" argspec: "args=[\'weights\', \'l1\', \'l2\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } - member_method { - name: "slice_input_producer" - argspec: "args=[\'tensor_list\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\'], " - } - member_method { - name: "string_input_producer" - argspec: "args=[\'string_tensor\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\'], " - } member_method { name: "summary_iterator" argspec: "args=[\'path\'], varargs=None, keywords=None, defaults=None" -- GitLab From 7ec8114697a78271277c1b81707f53057d047901 Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Wed, 5 Sep 2018 17:47:58 -0700 Subject: [PATCH 169/540] Modify tags for internal CI PiperOrigin-RevId: 211730301 --- tensorflow/contrib/lite/testing/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 89912fd116..0b3a97d4f5 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -173,6 +173,7 @@ tf_cc_test( srcs = ["tflite_driver_test.cc"], data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"], tags = [ + "no_oss", # b/112769036 "tflite_not_portable_android", "tflite_not_portable_ios", ], -- GitLab From ad5c0c4d091c93ef65e91c55cb4df065d0c7a989 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 5 Sep 2018 18:16:33 -0700 Subject: [PATCH 170/540] [tf.data] Move all C++ code inside the `tensorflow::data` namespace. PiperOrigin-RevId: 211733735 --- tensorflow/compiler/jit/xla_device_ops.h | 18 +++++------ .../bigtable/kernels/bigtable_kernels.cc | 6 ++++ .../contrib/bigtable/kernels/bigtable_lib.h | 4 +++ .../kernels/bigtable_lookup_dataset_op.cc | 2 ++ .../kernels/bigtable_prefix_key_dataset_op.cc | 2 ++ .../kernels/bigtable_range_key_dataset_op.cc | 2 ++ .../bigtable_sample_key_pairs_dataset_op.cc | 2 ++ .../bigtable_sample_keys_dataset_op.cc | 2 ++ .../kernels/bigtable_scan_dataset_op.cc | 2 ++ .../data/kernels/assert_next_dataset_op.cc | 2 ++ .../contrib/data/kernels/csv_dataset_op.cc | 2 ++ .../kernels/directed_interleave_dataset_op.cc | 4 +-- .../data/kernels/identity_indexed_dataset.cc | 2 ++ .../data/kernels/ignore_errors_dataset_op.cc | 4 +-- .../contrib/data/kernels/indexed_dataset.cc | 5 ++-- .../contrib/data/kernels/indexed_dataset.h | 2 ++ .../contrib/data/kernels/lmdb_dataset_op.cc | 2 ++ .../data/kernels/prefetching_kernels.cc | 4 ++- .../data/kernels/threadpool_dataset_op.cc | 2 ++ .../contrib/data/kernels/unique_dataset_op.cc | 4 +-- .../hadoop/kernels/hadoop_dataset_ops.cc | 4 ++- tensorflow/core/framework/dataset.cc | 2 ++ tensorflow/core/framework/dataset.h | 30 ++++++++++++++----- .../framework/dataset_stateful_op_whitelist.h | 11 ++++--- tensorflow/core/framework/stats_aggregator.h | 3 ++ .../core/kernels/data/batch_dataset_op.cc | 4 +-- .../core/kernels/data/cache_dataset_ops.cc | 4 +-- .../core/kernels/data/captured_function.cc | 2 ++ .../core/kernels/data/captured_function.h | 8 +++++ .../kernels/data/concatenate_dataset_op.cc | 4 +-- tensorflow/core/kernels/data/dataset_ops.cc | 2 ++ tensorflow/core/kernels/data/dataset_utils.cc | 6 ++-- tensorflow/core/kernels/data/dataset_utils.h | 6 ++-- .../data/dense_to_sparse_batch_dataset_op.cc | 4 +-- .../data/filter_by_component_dataset_op.cc | 4 +-- .../core/kernels/data/filter_dataset_op.cc | 4 +-- .../core/kernels/data/flat_map_dataset_op.cc | 6 ++-- .../core/kernels/data/generator_dataset_op.cc | 4 +++ .../core/kernels/data/generator_dataset_op.h | 2 ++ .../data/group_by_reducer_dataset_op.cc | 2 ++ .../data/group_by_window_dataset_op.cc | 2 ++ .../kernels/data/interleave_dataset_op.cc | 8 ++--- tensorflow/core/kernels/data/iterator_ops.cc | 19 +++++++++++- tensorflow/core/kernels/data/iterator_ops.h | 2 ++ .../kernels/data/map_and_batch_dataset_op.cc | 4 +-- .../core/kernels/data/map_dataset_op.cc | 4 +-- tensorflow/core/kernels/data/map_defun_op.cc | 4 ++- .../core/kernels/data/optimize_dataset_op.cc | 2 ++ tensorflow/core/kernels/data/optional_ops.cc | 2 ++ tensorflow/core/kernels/data/optional_ops.h | 2 ++ .../kernels/data/padded_batch_dataset_op.cc | 4 +-- .../data/parallel_interleave_dataset_op.cc | 8 ++--- .../kernels/data/parallel_map_dataset_op.cc | 4 +-- .../kernels/data/parallel_map_iterator.cc | 2 ++ .../core/kernels/data/parallel_map_iterator.h | 2 ++ .../kernels/data/parse_example_dataset_op.cc | 4 +-- .../core/kernels/data/prefetch_autotuner.cc | 2 ++ .../core/kernels/data/prefetch_autotuner.h | 2 ++ .../kernels/data/prefetch_autotuner_test.cc | 2 ++ .../core/kernels/data/prefetch_dataset_op.cc | 5 ++++ .../core/kernels/data/prefetch_dataset_op.h | 2 ++ .../core/kernels/data/random_dataset_op.cc | 4 +-- .../core/kernels/data/range_dataset_op.cc | 4 +-- .../core/kernels/data/reader_dataset_ops.cc | 4 +-- .../core/kernels/data/repeat_dataset_op.cc | 4 +-- .../core/kernels/data/scan_dataset_op.cc | 4 +-- .../core/kernels/data/shuffle_dataset_op.cc | 4 +-- .../kernels/data/single_threaded_executor.cc | 2 ++ .../kernels/data/single_threaded_executor.h | 2 ++ .../data/single_threaded_executor_test.cc | 2 ++ .../core/kernels/data/skip_dataset_op.cc | 4 +-- .../core/kernels/data/slide_dataset_op.cc | 4 +-- .../data/sparse_tensor_slice_dataset_op.cc | 4 +-- .../core/kernels/data/sql/driver_manager.cc | 4 +-- .../core/kernels/data/sql/driver_manager.h | 4 +-- .../core/kernels/data/sql/query_connection.h | 3 +- .../data/sql/sqlite_query_connection.cc | 4 +-- .../data/sql/sqlite_query_connection.h | 4 +-- .../core/kernels/data/sql_dataset_ops.cc | 5 ++-- .../data/stats_aggregator_dataset_op.cc | 2 ++ .../core/kernels/data/stats_aggregator_ops.cc | 2 ++ .../core/kernels/data/stats_dataset_ops.cc | 2 ++ .../core/kernels/data/take_dataset_op.cc | 4 +-- .../core/kernels/data/tensor_dataset_op.cc | 4 +-- .../kernels/data/tensor_queue_dataset_op.cc | 4 +-- .../kernels/data/tensor_slice_dataset_op.cc | 4 +-- .../core/kernels/data/unbatch_dataset_op.cc | 4 +-- .../core/kernels/data/window_dataset.cc | 2 ++ tensorflow/core/kernels/data/window_dataset.h | 2 ++ .../core/kernels/data/window_dataset_op.cc | 4 +-- tensorflow/core/kernels/data/writer_ops.cc | 3 +- .../core/kernels/data/zip_dataset_op.cc | 4 +-- 92 files changed, 259 insertions(+), 119 deletions(-) diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 13da5d2f94..49c8582682 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -198,33 +198,33 @@ class XlaAssignVariableOp : public AsyncOpKernel { \ REGISTER_KERNEL_BUILDER( \ Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"), \ - GeneratorDatasetOp); \ + data::GeneratorDatasetOp); \ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") \ .Device(DEVICE) \ .HostMemory("buffer_size") \ .HostMemory("input_dataset") \ .HostMemory("handle"), \ - PrefetchDatasetOp); \ + data::PrefetchDatasetOp); \ \ REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE), \ - IteratorHandleOp); \ + data::IteratorHandleOp); \ REGISTER_KERNEL_BUILDER( \ Name("MakeIterator").Device(DEVICE).HostMemory("dataset"), \ - MakeIteratorOp); \ + data::MakeIteratorOp); \ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE), \ - AnonymousIteratorHandleOp); \ + data::AnonymousIteratorHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \ - IteratorGetNextOp); \ + data::IteratorGetNextOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \ - IteratorGetNextSyncOp); \ + data::IteratorGetNextSyncOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ - IteratorToStringHandleOp); \ + data::IteratorToStringHandleOp); \ REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2") \ .Device(DEVICE) \ .HostMemory("string_handle"), \ - IteratorFromStringHandleOp); \ + data::IteratorFromStringHandleOp); \ REGISTER_KERNEL_BUILDER(Name(FunctionLibraryDefinition::kArgOp) \ .Device(DEVICE) \ .HostMemory("output") \ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc index a25a641cdb..6138d79126 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc @@ -172,6 +172,11 @@ class BigtableTableOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU), BigtableTableOp); +} // namespace + +namespace data { +namespace { + class ToBigtableOp : public AsyncOpKernel { public: explicit ToBigtableOp(OpKernelConstruction* ctx) @@ -354,5 +359,6 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU), ToBigtableOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h index a2a5df1037..4652021fec 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h @@ -79,6 +79,8 @@ class BigtableTableResource : public ResourceBase { ::google::cloud::bigtable::noex::Table table_; }; +namespace data { + // BigtableReaderDatasetIterator is an abstract class for iterators from // datasets that are "readers" (source datasets, not transformation datasets) // that read from Bigtable. @@ -138,6 +140,8 @@ class BigtableReaderDatasetIterator : public DatasetIterator { ::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_); }; +} // namespace data + } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_ diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc index bd32672aa9..11f530e82a 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableLookupDatasetOp : public UnaryDatasetOpKernel { @@ -226,4 +227,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU), BigtableLookupDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc index a803fdcb49..5cab729d9c 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtablePrefixKeyDatasetOp : public DatasetOpKernel { @@ -111,4 +112,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU), BigtablePrefixKeyDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc index 5cd0371c79..4dc4647bd2 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableRangeKeyDatasetOp : public DatasetOpKernel { @@ -117,4 +118,5 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU), BigtableRangeKeyDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc index 6928d9423c..736775bdac 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel { @@ -205,4 +206,5 @@ REGISTER_KERNEL_BUILDER( BigtableSampleKeyPairsDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc index a759fb5063..208b7b3e08 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableSampleKeysDatasetOp : public DatasetOpKernel { @@ -118,4 +119,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU), BigtableSampleKeysDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc index 78a920b077..9407855fe8 100644 --- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc +++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { namespace { class BigtableScanDatasetOp : public DatasetOpKernel { @@ -224,4 +225,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU), BigtableScanDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc index e36c9c0634..c19a609780 100644 --- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -150,4 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), AssertNextDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 0ba905b92e..74107d5242 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_inputstream.h" namespace tensorflow { +namespace data { namespace { class CSVDatasetOp : public DatasetOpKernel { @@ -851,4 +852,5 @@ class CSVDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index ccf7ec1f84..a5321620bf 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -276,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU), DirectedInterleaveDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc index 4718c1c8b9..c3cb45dbf7 100644 --- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc +++ b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace data { namespace { class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel { @@ -150,4 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU), IdentityIndexedDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index db24e60846..beec344534 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -137,5 +137,5 @@ REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU), IgnoreErrorsDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc index c69564a31b..ced8ab0d60 100644 --- a/tensorflow/contrib/data/kernels/indexed_dataset.cc +++ b/tensorflow/contrib/data/kernels/indexed_dataset.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { - +namespace data { namespace { Status VerifyTypesMatch(const DataTypeVector& expected, @@ -367,6 +367,7 @@ REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU), MaterializeDatasetOp); REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU), IndexedDatasetGet); -} // namespace +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h index 6149de888c..7aa2d3fdbc 100644 --- a/tensorflow/contrib/data/kernels/indexed_dataset.h +++ b/tensorflow/contrib/data/kernels/indexed_dataset.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace data { // TODO(saeta): Urgh, this is ugly. class MaterializedIndexedDataset { @@ -112,6 +113,7 @@ Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, Tensor* tensor); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc index 80f39992fb..d233c1f8ec 100644 --- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "lmdb.h" // NOLINT(build/include) namespace tensorflow { +namespace data { namespace { class LMDBDatasetOp : public DatasetOpKernel { @@ -212,4 +213,5 @@ class LMDBDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index 725f8933c9..078de717e0 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace data { namespace { struct BufferElement { @@ -1114,5 +1115,6 @@ REGISTER_KERNEL_BUILDER( Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU), MultiDeviceIteratorFromStringHandleOp); -} // anonymous namespace +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index ab584504a0..30fa97a636 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/util/work_sharder.h" namespace tensorflow { +namespace data { namespace { class ThreadPoolResource : public ResourceBase { @@ -214,4 +215,5 @@ REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU), ThreadPoolDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 6fbf5d2ebb..57fc5697a4 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -219,5 +219,5 @@ REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU), UniqueDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc index 80b2d3e08b..2bf6097d01 100644 --- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc +++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/platform/file_system.h" namespace tensorflow { +namespace data { namespace { static const size_t kSyncMarkerSize = 16; @@ -332,9 +333,10 @@ class SequenceFileDatasetOp : public DatasetOpKernel { }; DataTypeVector output_types_; }; -} // namespace REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU), SequenceFileDatasetOp); +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 9ffd8e1ee0..5281c56f04 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" namespace tensorflow { +namespace data { namespace { @@ -329,4 +330,5 @@ void BackgroundWorker::WorkerLoop() { } } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 04865a1d4f..4e51fba048 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -40,6 +40,13 @@ limitations under the License. namespace tensorflow { +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class Node; + +namespace data { + class DatasetBase; class SerializationContext; @@ -66,11 +73,6 @@ class IteratorStateWriter { virtual ~IteratorStateWriter() {} }; -// Forward declarations to avoid introducing a dependency on headers in -// "tensorflow/core/graph/...". -class GraphDefBuilder; -class Node; - // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. class GraphDefBuilderWrapper { public: @@ -222,8 +224,7 @@ class GraphDefBuilderWrapper { return (str_util::EndsWith(op_def->name(), "Dataset") && op_def->output_arg_size() == 1 && op_def->output_arg(0).type() == DT_VARIANT) || - dataset::WhitelistedStatefulOpRegistry::Global()->Contains( - op_def->name()); + WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name()); } bool HasAttr(const string& op_type_name, const string& attr_name) const; @@ -751,6 +752,21 @@ class BackgroundWorker { std::deque> work_queue_ GUARDED_BY(mu_); }; +} // namespace data + +// TODO(b/114112161): Remove these aliases when all users have moved over to the +// `tensorflow::data` namespace. +using data::DatasetBase; +using data::DatasetContext; +using data::DatasetIterator; +using data::DatasetOpKernel; +using data::IteratorBase; +using data::IteratorContext; +using data::IteratorStateReader; +using data::IteratorStateWriter; +using data::SerializationContext; +using data::UnaryDatasetOpKernel; + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h index 3b48999edb..21c21723d0 100644 --- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h +++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -namespace dataset { +namespace data { // Registry for stateful ops that need to be used in dataset functions. // See below macro for usage details. class WhitelistedStatefulOpRegistry { @@ -47,7 +47,7 @@ class WhitelistedStatefulOpRegistry { std::set op_names_; }; -} // namespace dataset +} // namespace data // Use this macro to whitelist an op that is marked stateful but needs to be // used inside a map_fn in an input pipeline. This is only needed if you wish @@ -67,10 +67,9 @@ class WhitelistedStatefulOpRegistry { WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ - static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ - ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ - name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ + static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name) } // namespace tensorflow diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h index 4a18efc940..af53ed0a3c 100644 --- a/tensorflow/core/framework/stats_aggregator.h +++ b/tensorflow/core/framework/stats_aggregator.h @@ -25,6 +25,8 @@ namespace tensorflow { class Summary; +namespace data { + // A `StatsAggregator` accumulates statistics incrementally. A // `StatsAggregator` can accumulate multiple different statistics, distinguished // by a string name. @@ -87,6 +89,7 @@ class StatsAggregatorResource : public ResourceBase { const std::shared_ptr stats_aggregator_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index f9b5353724..a25f78c6f1 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -241,5 +241,5 @@ REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU), BatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 6ca0bcd37d..221b5ad835 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level description of @@ -891,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU), CacheDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 186740c2ac..ad2365b25b 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/platform/notification.h" namespace tensorflow { +namespace data { /* static */ Status CapturedFunction::Create( @@ -418,4 +419,5 @@ CapturedFunction::CapturedFunction(const NameAttrList& func, captured_inputs_(std::move(captured_inputs)), use_inter_op_parallelism_(use_inter_op_parallelism) {} +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index 9526da22d1..e44bc78b1c 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -32,6 +32,8 @@ class Device; class OpKernelContext; class ResourceMgr; +namespace data { + // A `CapturedFunction` encapsulates a TensorFlow function and all of // the runtime support required to execute it. // @@ -141,6 +143,12 @@ class CapturedFunction { TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); }; +} // namespace data + +// TODO(b/114112161): Remove these aliases when all users have moved over to the +// `tensorflow::data` namespace. +using data::CapturedFunction; + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index c361a9adcb..a04f150e71 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU), ConcatenateDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index c71d027f23..bd1ccd5b5d 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -48,4 +49,5 @@ class DatasetToGraphOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU), DatasetToGraphOp); +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index d85ef1cbab..e7ac368ae3 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -17,8 +17,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" namespace tensorflow { - -namespace dataset { +namespace data { Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector& input_element, @@ -45,6 +44,5 @@ Status MakeIteratorFromInputElement( ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator); } -} // namespace dataset - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 6c4191c2be..234856ea39 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -20,16 +20,14 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - -namespace dataset { +namespace data { Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, std::unique_ptr* out_iterator); -} // namespace dataset - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index 9770bc025d..237511a07d 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -301,5 +301,5 @@ REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU), DenseToSparseBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc index ce577397c5..a7e3a56727 100644 --- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -166,5 +166,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU), FilterByLastComponentDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index bbce001eaf..bf0aecaf3c 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -280,5 +280,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterDataset").Device(DEVICE_CPU), FilterDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index b1eb2fd849..e3c45ef86c 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -245,7 +245,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { private: Status BuildCurrentElementIteratorLocked(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return dataset::MakeIteratorFromInputElement( + return MakeIteratorFromInputElement( ctx, captured_func_inputs_, element_index_++, dataset()->captured_func_.get(), prefix(), ¤t_element_iterator_); @@ -285,5 +285,5 @@ REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU), FlatMapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index ccee690d7e..ac5cc1b2c1 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -188,10 +189,13 @@ void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, std::move(finalize_func), output_types_, output_shapes_); } +namespace { REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU), GeneratorDatasetOp); REGISTER_KERNEL_BUILDER( Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"), GeneratorDatasetOp); +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h index 8407543136..d23ed97ec3 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.h +++ b/tensorflow/core/kernels/data/generator_dataset_op.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" namespace tensorflow { +namespace data { class GeneratorDatasetOp : public DatasetOpKernel { public: @@ -36,5 +37,6 @@ class GeneratorDatasetOp : public DatasetOpKernel { NameAttrList finalize_func_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index 130f04da3e..d6ee42a7c6 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -433,4 +434,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU), GroupByReducerDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 46a3185b49..e4fa557598 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -549,4 +550,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU), GroupByWindowDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 716e040277..0768f46665 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -201,7 +201,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(input_impl_->GetNext( ctx, &args_list_[cycle_index_], &end_of_input_)); if (!end_of_input_) { - TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( ctx, args_list_[cycle_index_], cycle_index_, dataset()->captured_func_.get(), prefix(), ¤t_elements_[cycle_index_])); @@ -288,7 +288,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), &args_list_[idx][i])); } - TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( ctx, args_list_[idx], idx, dataset()->captured_func_.get(), prefix(), ¤t_elements_[idx])); TF_RETURN_IF_ERROR( @@ -330,5 +330,5 @@ REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU), InterleaveDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 4e9b280968..fe6d705eab 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -236,6 +236,8 @@ class IteratorResource : public ResourceBase { const std::vector output_shapes_; }; +namespace { + // Helper class for reading data from a VariantTensorData object. class VariantTensorDataReader : public IteratorStateReader { public: @@ -443,6 +445,8 @@ class IteratorStateVariant { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, kIteratorVariantTypeName); +} // namespace + // Note that IteratorHandleOp holds a reference to the resource it creates. If // cleaning up resources with DestroyResourceOp is important, consider creating // resource containers with AnonymousIteratorHandleOp instead. @@ -622,6 +626,8 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } +namespace { + class ToSingleElementOp : public AsyncOpKernel { public: explicit ToSingleElementOp(OpKernelConstruction* ctx) @@ -887,6 +893,8 @@ class OneShotIteratorOp : public AsyncOpKernel { const int graph_def_version_; }; +} // namespace + void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { IteratorResource* iterator; OP_REQUIRES_OK_ASYNC( @@ -957,6 +965,8 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) { } } +namespace { + class IteratorGetNextAsOptionalOp : public AsyncOpKernel { public: explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) @@ -1037,6 +1047,8 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel { std::vector output_shapes_; }; +} // namespace + void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) { const Tensor& resource_handle_t = ctx->input(0); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), @@ -1108,6 +1120,8 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) { resource_handle_t->scalar()() = resource_handle; } +namespace { + class SerializeIteratorOp : public OpKernel { public: explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -1202,4 +1216,7 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), DeserializeIteratorOp); +} // namespace + +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 723564286c..8a2b2639a7 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" namespace tensorflow { +namespace data { class IteratorResource; @@ -142,6 +143,7 @@ class IteratorFromStringHandleOp : public OpKernel { std::vector output_shapes_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 8b0c9ad6b2..27c89b3661 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/core/platform/tracing.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -675,5 +675,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU), MapAndBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 6c45fcafcc..306486b96a 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -196,5 +196,5 @@ class MapDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index cc4d7976f8..3c562fc7f3 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/util/reffed_status_callback.h" namespace tensorflow { +namespace data { namespace { void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, @@ -191,8 +192,9 @@ class MapDefunOp : public AsyncOpKernel { const OpKernel* kernel_; const size_t iter_; }; -}; // namespace +}; REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 6263dc3cf8..d5b725eac9 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -270,4 +271,5 @@ REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU), OptimizeDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc index cfac45dbc7..b372d31a93 100644 --- a/tensorflow/core/kernels/data/optional_ops.cc +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" namespace tensorflow { +namespace data { namespace { const char kOptionalVariantTypeName[] = "tensorflow::data::Optional"; @@ -267,4 +268,5 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { return Status::OK(); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h index 6f25567678..2cbf2933f5 100644 --- a/tensorflow/core/kernels/data/optional_ops.h +++ b/tensorflow/core/kernels/data/optional_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_tensor_data.h" namespace tensorflow { +namespace data { // Stores a DT_VARIANT value representing an Optional with the given value // in the `output_index`^th output of the given kernel execution context. @@ -31,6 +32,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, // in the `output_index`^th output of the given kernel execution context. Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index be45eac46e..fd0e6c4cd0 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -382,5 +382,5 @@ REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU), PaddedBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index f6b3fd97e3..f8287cf0e3 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -684,7 +684,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { { tf_shared_lock l(ckpt_mu_); worker_thread_states_[thread_index].iterator_creation_status = - dataset::MakeIteratorFromInputElement( + MakeIteratorFromInputElement( ctx.get(), worker_thread_states_[thread_index].input, thread_index, dataset()->captured_func_.get(), prefix(), &worker_thread_states_[thread_index].iterator); @@ -914,7 +914,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { worker_thread_states_[index].iterator.reset(); } else { std::unique_ptr iterator; - Status s = dataset::MakeIteratorFromInputElement( + Status s = MakeIteratorFromInputElement( ctx, worker_thread_states_[index].input, index, dataset()->captured_func_.get(), prefix(), &iterator); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator)); @@ -1068,5 +1068,5 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), ParallelInterleaveDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index bff54813d6..ac5ed286ee 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -180,5 +180,5 @@ REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU), ParallelMapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 61f8139b9e..4ae742aaaf 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -20,6 +20,7 @@ limitations under the License. #include namespace tensorflow { +namespace data { namespace { class ParallelMapIterator : public DatasetBaseIterator { @@ -333,4 +334,5 @@ std::unique_ptr NewParallelMapIterator( std::move(map_func), num_parallel_calls)); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index 7e6cc586f3..dc26c5cf25 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" namespace tensorflow { +namespace data { // A function that transforms elements of one dataset into another // asynchronously. The arguments are: @@ -47,6 +48,7 @@ std::unique_ptr NewParallelMapIterator( const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, int32 num_parallel_calls); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_ diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 9057800d94..0cf5db017b 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/util/example_proto_fast_parsing.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -368,5 +368,5 @@ REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU), ParseExampleDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc index b3272f6bcd..533d0bd5d2 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/prefetch_autotuner.h" namespace tensorflow { +namespace data { PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size) : buffer_limit_(initial_buffer_size) { @@ -43,4 +44,5 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) { } } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h index fa8a184072..8693205512 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.h +++ b/tensorflow/core/kernels/data/prefetch_autotuner.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { +namespace data { // PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator. // @@ -66,6 +67,7 @@ class PrefetchAutotuner { Mode mode_ = Mode::kDisabled; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_ diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc index 29a8cc50cd..cfc324fc7e 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace data { namespace { TEST(PrefetchAutotuner, Disabled) { @@ -79,4 +80,5 @@ TEST(PrefetchAutotuner, EnabledSteady) { } } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 50efbcbe2a..a7a2935195 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/core/error_codes.pb.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -346,6 +347,7 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, *output = new Dataset(ctx, input, buffer_size); } +namespace { REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU), PrefetchDatasetOp); REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") @@ -354,4 +356,7 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") .HostMemory("input_dataset") .HostMemory("handle"), PrefetchDatasetOp); +} // namespace + +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h index c40c4b00da..588fb25a06 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.h +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/prefetch_autotuner.h" namespace tensorflow { +namespace data { class PrefetchDatasetOp : public UnaryDatasetOpKernel { public: @@ -34,6 +35,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { class Dataset; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index 7817170e73..044a791a3f 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -151,5 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU), RandomDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index aa38775125..89fbaae369 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -142,5 +142,5 @@ REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU), RangeDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index 086b552936..c474cb4773 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_inputstream.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -691,5 +691,5 @@ REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU), TFRecordDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 299949b99f..94e96635ab 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -250,5 +250,5 @@ REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU), RepeatDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index fccad933d0..6e515d6cc8 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -279,5 +279,5 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 93a4376836..66466d6a36 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { - +namespace data { namespace { const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds. @@ -620,5 +620,5 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU), ShuffleAndRepeatDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc index e785b8b4d5..5b084a16f0 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor.cc +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { +namespace data { namespace { typedef gtl::InlinedVector TensorValueVec; @@ -375,4 +376,5 @@ Status NewSingleThreadedExecutor(const LocalExecutorParams& params, return Status::OK(); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h index 15836b24c9..e934352a1d 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor.h +++ b/tensorflow/core/kernels/data/single_threaded_executor.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/executor.h" namespace tensorflow { +namespace data { // Creates a new `Executor` for executing `graph` synchronously on the caller // thread. @@ -55,6 +56,7 @@ Status NewSingleThreadedExecutor(const LocalExecutorParams& params, std::unique_ptr graph, Executor** executor); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc index f8b5769197..6244e287bb 100644 --- a/tensorflow/core/kernels/data/single_threaded_executor_test.cc +++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" namespace tensorflow { +namespace data { namespace { class ExecutorTest : public ::testing::Test { @@ -327,4 +328,5 @@ BENCHMARK(BM_FeedInputFetchOutput); #endif } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index fe7ef38d5f..b8c7fb15f4 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -187,5 +187,5 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index 14df3a6801..1e73cfc753 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -293,5 +293,5 @@ REGISTER_KERNEL_BUILDER(Name("SlideDataset").Device(DEVICE_CPU), SlideDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index e526578701..85b1e50695 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/util/sparse/sparse_tensor.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -274,5 +274,5 @@ TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL); #undef REGISTER_DATASET_KERNEL } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/driver_manager.cc b/tensorflow/core/kernels/data/sql/driver_manager.cc index ffabda1a8a..783d1e6cb2 100644 --- a/tensorflow/core/kernels/data/sql/driver_manager.cc +++ b/tensorflow/core/kernels/data/sql/driver_manager.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h" namespace tensorflow { - +namespace data { namespace sql { std::unique_ptr DriverManager::CreateQueryConnection( @@ -30,5 +30,5 @@ std::unique_ptr DriverManager::CreateQueryConnection( } } // namespace sql - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h index a34691b5a2..c5428f396b 100644 --- a/tensorflow/core/kernels/data/sql/driver_manager.h +++ b/tensorflow/core/kernels/data/sql/driver_manager.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/sql/query_connection.h" namespace tensorflow { - +namespace data { namespace sql { // A factory class for creating `QueryConnection` instances. @@ -35,7 +35,7 @@ class DriverManager { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h index e9ffca202f..2fd229a9bf 100644 --- a/tensorflow/core/kernels/data/sql/query_connection.h +++ b/tensorflow/core/kernels/data/sql/query_connection.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" namespace tensorflow { +namespace data { class IteratorContext; @@ -63,7 +64,7 @@ class QueryConnection { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc index 7cd07bd8ec..5108e83976 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { - +namespace data { namespace sql { SqliteQueryConnection::SqliteQueryConnection() {} @@ -115,5 +115,5 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry( } } // namespace sql - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h index 81b19530b7..175492c49d 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { - +namespace data { namespace sql { class SqliteQueryConnection : public QueryConnection { @@ -50,7 +50,7 @@ class SqliteQueryConnection : public QueryConnection { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 2aa153fcfa..6bbe459332 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -24,8 +24,9 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { - +namespace data { namespace { + // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following ops. @@ -211,5 +212,5 @@ class SqlDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index 75af73df54..f5314f7a75 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { @@ -135,4 +136,5 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU), SetStatsAggregatorDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index b133cfab54..a7ded67876 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" namespace tensorflow { +namespace data { namespace { static mutex* get_counters_map_lock() { @@ -145,4 +146,5 @@ REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU), StatsAggregatorSummaryOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 8957f5d997..e9e42f05a1 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // This op defines a `Dataset` that passes through its input elements and @@ -248,4 +249,5 @@ REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU), BytesProducedStatsDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index e5c237dfaa..e5cdfdd732 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -174,5 +174,5 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 1192fafc4c..e1cefd23d8 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -140,5 +140,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU), TensorDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc index ccd5e60acc..2ed636a400 100644 --- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a, @@ -648,5 +648,5 @@ REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU), EnqueueInQueueDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index dc32cd23e5..7dc64b0a75 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -168,5 +168,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU), TensorSliceDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 1a79f72b28..81c432b938 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -204,5 +204,5 @@ REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU), UnbatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index 0ab6beabfc..2ad4711aab 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace data { namespace { class WindowDataset : public DatasetBase { @@ -107,4 +108,5 @@ Status NewWindowDataset(std::vector> elements, return Status::OK(); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h index 7bd31a0bc7..84cb3c7860 100644 --- a/tensorflow/core/kernels/data/window_dataset.h +++ b/tensorflow/core/kernels/data/window_dataset.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { +namespace data { // Creates a dataset representing an eagerly-collected window of elements. // @@ -43,6 +44,7 @@ Status NewWindowDataset(std::vector> elements, std::vector output_shapes, DatasetBase** out_dataset); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 41bf9d43fe..3975086841 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/window_dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU), WindowDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 1c49874a6a..3f76695bb1 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/file_system.h" namespace tensorflow { - +namespace data { namespace { class ToTFRecordOp : public AsyncOpKernel { @@ -104,4 +104,5 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), ToTFRecordOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index e4306579ed..61a2078f46 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -175,5 +175,5 @@ class ZipDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow -- GitLab From 19ac7a58287b90e1cd73c8e34438a8db915f481b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 19:51:23 -0700 Subject: [PATCH 171/540] changes to ctc_beam_search PiperOrigin-RevId: 211741560 --- tensorflow/core/util/ctc/ctc_beam_entry.h | 2 +- tensorflow/core/util/ctc/ctc_beam_scorer.h | 2 +- tensorflow/core/util/ctc/ctc_beam_search.h | 1 + tensorflow/core/util/ctc/ctc_decoder.h | 2 +- tensorflow/core/util/ctc/ctc_loss_util.h | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index 973e315f09..24002e72a0 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h index 1a622babe1..1e45a8abd3 100644 --- a/tensorflow/core/util/ctc/ctc_beam_scorer.h +++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange // Collection of scoring classes that can be extended and provided to the // CTCBeamSearchDecoder to incorporate additional scoring logic (such as a diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index 5e2aeb7830..6fbb1ed0da 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index 3be36822e5..b55d7d77ac 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h index 36be9e92ef..054412d388 100644 --- a/tensorflow/core/util/ctc/ctc_loss_util.h +++ b/tensorflow/core/util/ctc/ctc_loss_util.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ -- GitLab From dc38a06da8295f4cc86fa13bb285577aa3f41858 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 6 Sep 2018 03:14:44 +0000 Subject: [PATCH 172/540] Upcast to float for better conversion, based on review feedback. Signed-off-by: Yong Tang --- tensorflow/core/kernels/non_max_suppression_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index c93f668801..81ce6d6e95 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -146,7 +146,7 @@ void DoNonMaxSuppressionOp( std::priority_queue, decltype(cmp)> candidate_priority_queue(cmp); for (int i = 0; i < scores_data.size(); ++i) { - if (scores_data[i] > static_cast(score_threshold)) { + if (static_cast(scores_data[i]) > score_threshold) { candidate_priority_queue.emplace(Candidate({i, scores_data[i]})); } } -- GitLab From 692a14863c0a6c6ed4c5cd0fffb1bfc6630682d8 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 6 Sep 2018 03:15:32 +0000 Subject: [PATCH 173/540] Add default type as DT_FLOAT to maintain backward-compatibility and fix test failure in: ``` //tensorflow/core/ops/compat:backwards_compatibility_test ``` Signed-off-by: Yong Tang --- tensorflow/core/ops/image_ops.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index abb4e6fcf6..5427275284 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -688,7 +688,7 @@ REGISTER_OP("NonMaxSuppressionV2") .Input("max_output_size: int32") .Input("iou_threshold: float") .Output("selected_indices: int32") - .Attr("T: {half, float}") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { // Get inputs and validate ranks. ShapeHandle boxes; @@ -718,7 +718,7 @@ REGISTER_OP("NonMaxSuppressionV3") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") - .Attr("T: {half, float}") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn(NMSShapeFn); REGISTER_OP("NonMaxSuppressionV4") @@ -729,7 +729,7 @@ REGISTER_OP("NonMaxSuppressionV4") .Input("score_threshold: float") .Output("selected_indices: int32") .Output("valid_outputs: int32") - .Attr("T: {half, float}") + .Attr("T: {half, float} = DT_FLOAT") .Attr("pad_to_max_output_size: bool = false") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(NMSShapeFn(c)); -- GitLab From f4ae136265d3d3116a008b98ccf21d0791b878fd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Sep 2018 20:22:34 -0700 Subject: [PATCH 174/540] Fix ordering of tf.GraphKeys.VARIABLES line in renames_v2.py PiperOrigin-RevId: 211744058 --- tensorflow/tools/compatibility/renames_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index 29c62763b0..7e66ad816a 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -65,9 +65,9 @@ renames = { 'tf.fft': 'tf.spectral.fft', 'tf.floor': 'tf.math.floor', 'tf.gather_nd': 'tf.manip.gather_nd', + 'tf.GraphKeys.VARIABLES': 'tf.GraphKeys.GLOBAL_VARIABLES', 'tf.greater': 'tf.math.greater', 'tf.greater_equal': 'tf.math.greater_equal', - 'tf.GraphKeys.VARIABLES': 'tf.GraphKeys.GLOBAL_VARIABLES', 'tf.ifft': 'tf.spectral.ifft', 'tf.igamma': 'tf.math.igamma', 'tf.igammac': 'tf.math.igammac', -- GitLab From 5393c8f0dc57857c93482bff67f1134aae9af594 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 5 Sep 2018 22:20:40 -0700 Subject: [PATCH 175/540] Add `TraceCollector::IsEnabled(bool)` method in order to test when tracing is enabled. Some builds install a `TraceCollector` at process startup, but it is mostly not enabled. This inhibits the recent optimization to avoid accessing `OpKernel::name()` and `OpKernel::type_string()` every time a kernel is launched. By caching the `TraceCollector` in the `TracingDevice` and adding a method to enquire about its state, we increase the applicability of the optimization. PiperOrigin-RevId: 211752728 --- tensorflow/core/common_runtime/tracing_device.h | 5 ++++- tensorflow/core/platform/default/device_tracer.cc | 5 +++++ tensorflow/core/platform/tracing.h | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h index 39215efa35..e1b163074f 100644 --- a/tensorflow/core/common_runtime/tracing_device.h +++ b/tensorflow/core/common_runtime/tracing_device.h @@ -35,8 +35,11 @@ class TracingDevice : public Device { : Device(env, attributes) {} void Compute(OpKernel* op_kernel, OpKernelContext* context) override { + const tracing::TraceCollector* trace_collector = + tracing::GetTraceCollector(); if (TF_PREDICT_FALSE( - tracing::GetTraceCollector() || + (trace_collector && + trace_collector->IsEnabled(op_kernel->IsExpensive())) || tracing::GetEventCollector(tracing::EventCategory::kCompute))) { const string& op_name = op_kernel->name(); tracing::ScopedActivity activity(op_name, op_kernel->type_string(), diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc index ccddf1eafc..0389149469 100644 --- a/tensorflow/core/platform/default/device_tracer.cc +++ b/tensorflow/core/platform/default/device_tracer.cc @@ -321,6 +321,11 @@ class DeviceTracerImpl : public DeviceTracer, return nullptr; } + bool IsEnabled(bool is_expensive) const override { + // We don't do anything with 'Activities' so we are never 'enabled'. + return false; + } + protected: // This callback is used exclusively by CUPTIManager. friend class CUPTIManager; diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h index e5851f1dfe..9974bbbb4e 100644 --- a/tensorflow/core/platform/tracing.h +++ b/tensorflow/core/platform/tracing.h @@ -155,6 +155,10 @@ class TraceCollector { StringPiece name_part1, StringPiece name_part2, bool is_expensive) const = 0; + // Returns true if this activity handle tracking is enabled for an op of the + // given expensiveness. + virtual bool IsEnabled(bool is_expensive) const = 0; + protected: static string ConcatenateNames(StringPiece first, StringPiece second); -- GitLab From e23d522e943309cefae368a11c21ae37b6986165 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Wed, 5 Sep 2018 22:34:52 -0700 Subject: [PATCH 176/540] Allow creating a py EagerTensor that shares the underlying TensorHandle. This is so that gradients with respect to scalars pass (see the test added in backprop_test.py). A micro benchmark just calling constant_op.constant slows down a bit - this is inevitable as we are creating a new python object. After: walltime: ~2.1 Before: walltime: ~1.47 Linear regression benchmark is pretty much unchanged. PiperOrigin-RevId: 211753801 --- tensorflow/c/eager/c_api.cc | 13 +++++++ tensorflow/c/eager/c_api.h | 6 ++++ tensorflow/c/eager/c_api_test.cc | 25 ++++++++++++++ tensorflow/python/eager/backprop_test.py | 40 ++++++++++++++++++++++ tensorflow/python/eager/benchmarks_test.py | 5 +++ tensorflow/python/eager/pywrap_tensor.cc | 30 ++++++++++++++-- tensorflow/python/framework/constant_op.py | 3 +- 7 files changed, 118 insertions(+), 4 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 77e3878a94..349d9bcd7c 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -399,6 +399,19 @@ const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { : d->name().c_str(); } +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status) { + if (h == nullptr || h->handle == nullptr) { + status->status = tensorflow::errors::InvalidArgument( + "The passed in handle is a nullptr"); + return nullptr; + } + + h->handle->Ref(); + + return new TFE_TensorHandle(h->handle); +} + TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { if (h == nullptr || h->handle == nullptr) { status->status = tensorflow::errors::InvalidArgument( diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index eec2750d6e..337447eec9 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -171,6 +171,12 @@ TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( TFE_TensorHandle* h, TF_Status* status); +// Return a pointer to a new TFE_TensorHandle that shares the underlying tensor +// with `h`. On success, `status` is set to OK. On failure, `status` reflects +// the error and a nullptr is returned. +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor( + TFE_TensorHandle* h, TF_Status* status); + // This function will block till the operation that produces `h` has // completed. The memory returned might alias the internal memory used by // TensorFlow. Hence, callers should not mutate this memory (for example by diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 7126227cf5..55331022b9 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -1528,4 +1528,29 @@ TEST(CAPI, StringAttributes) { TFE_DeleteContext(ctx); TF_DeleteStatus(status); } + +TEST(CAPI, TestTFE_TensorHandleCopySharingUnderlyingTensorHandle) { + TFE_TensorHandle* h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); + + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + TFE_TensorHandle* h_shares_tensor = + TFE_TensorHandleCopySharingTensor(h, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_Tensor* t = TFE_TensorHandleResolve(h_shares_tensor, status.get()); + ASSERT_EQ(16, TF_TensorByteSize(t)); + float data[4] = {0}; + memcpy(&data[0], TF_TensorData(t), TF_TensorByteSize(t)); + EXPECT_EQ(1.0, data[0]); + EXPECT_EQ(2.0, data[1]); + EXPECT_EQ(3.0, data[2]); + EXPECT_EQ(4.0, data[3]); + TF_DeleteTensor(t); + + TFE_DeleteTensorHandle(h); + TFE_DeleteTensorHandle(h_shares_tensor); +} } // namespace diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index 6673178ee7..3319b440b4 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -957,5 +957,45 @@ class BackpropTest(test.TestCase): self.assertAllEqual(grad1, grad2) + @test_util.run_in_graph_and_eager_modes + def testDifferentiatingScalarCache(self): + # In the following test, if x2 = x1 (i.e the objects are the exact same), + # then y is essentially, 2*x1, and dy/dx1 = 2. + # When we had a pure scalar cache in eager, this would be the case. This + # test prevents us from going back to that case. + with backprop.GradientTape(persistent=False) as g: + x1 = constant_op.constant(3.0) + x2 = constant_op.constant(3.0) + g.watch(x1) + g.watch(x2) + y = x1 + x2 + grad = g.gradient(target=y, sources=[x1]) + self.assertEqual(self.evaluate(grad), [1.0]) + + def testVariablesAndConstantsProduceTheSameGradients(self): + + # In the following test, differentiating [y, z] against [a, b] gives: + # (dy/da + dz/da, dy/db + dz/db). + # If a and b are the same constant, dz/da will not be 0 (which it should + # be). + # This is solved by using variable since doing a read_value on a tensor will + # produce a new tensor and corresponding TensorHandle, and not reuse the + # same tensor (which would happen if we are using a cache and reusing + # EagerTensor objects). + def get_grads(a, b): + with backprop.GradientTape() as tape: + tape.watch([a, b]) + y = a**3 + z = b**2 + return tape.gradient([y, z], [a, b]) + + gradients_constants = get_grads( + constant_op.constant(2.0), constant_op.constant(2.0)) + gradients_variables = get_grads( + resource_variable_ops.ResourceVariable(2.0), + resource_variable_ops.ResourceVariable(2.0)) + self.assertAllEqual(gradients_constants, gradients_variables) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index a2e8422671..3bdaf0b214 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -175,6 +175,11 @@ class MicroBenchmarks(test.Benchmark): self._run(func, 30000) + def benchmark_create_constant(self): + func = lambda: constant_op.constant(3.0) + + self._run(func, 30000) + def benchmark_create_float_tensor_from_list_CPU(self): self._benchmark_create_tensor([[3.0]], dtypes.float32.as_datatype_enum, CPU) diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 86fbd24d68..432dcbc2e2 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -325,12 +325,36 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { PyObject* context = nullptr; PyObject* device = nullptr; PyObject* dtype = Py_None; - const char* kwlist[] = {"value", "context", "device", "dtype", nullptr}; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|O", + PyObject* other_value = nullptr; + const char* kwlist[] = {"value", "context", "device", + "dtype", "other_value", nullptr}; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO", const_cast(kwlist), &value, &context, - &device, &dtype)) { + &device, &dtype, &other_value)) { return -1; } + + if (other_value != nullptr) { + if (!EagerTensor_CheckExact(other_value)) { + PyErr_SetString(PyExc_TypeError, + tensorflow::strings::StrCat( + "Expecting an EagerTensor for other_value, got ", + Py_TYPE(other_value)->tp_name) + .c_str()); + + return -1; + } + EagerTensor* other = reinterpret_cast(other_value); + self->handle = + TFE_TensorHandleCopySharingTensor(other->handle, self->status); + + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + return -1; + } + + return 0; + } + // Extract dtype int desired_dtype = -1; if (dtype != Py_None) { diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index eca34ac26e..4b2706d4cf 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -105,7 +105,8 @@ def convert_to_eager_tensor(value, ctx, dtype=None): scalar_cache = ctx.scalar_cache() tensor = scalar_cache.get(cache_key, None) if tensor is not None: - return tensor + return ops.EagerTensor( + value, context=handle, device=device, dtype=dtype, other_value=tensor) t = ops.EagerTensor(value, context=handle, device=device, dtype=dtype) scalar_cache[cache_key] = t return t -- GitLab From 830c8a480a4a65540e60b638cd73b50801408c9b Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Wed, 5 Sep 2018 23:01:27 -0700 Subject: [PATCH 177/540] [FLR] Simplify the Run() (custom callframe) implementation. Profiling showed that we were wastefully (i) heap-allocating and freeing an Executor::Args object on each call, and (as a result) (ii) incurring extra function dispatch overhead in the callback. PiperOrigin-RevId: 211755493 --- tensorflow/core/common_runtime/function.cc | 33 ++++++++-------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 46bb8d92f8..b00e526309 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -925,29 +925,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } DCHECK(run_opts.runner != nullptr); - Executor::Args* exec_args = new Executor::Args; + Executor::Args exec_args; // Inherit the step_id from the caller. - exec_args->step_id = run_opts.step_id; - exec_args->rendezvous = run_opts.rendezvous; - exec_args->stats_collector = run_opts.stats_collector; - exec_args->cancellation_manager = run_opts.cancellation_manager; - exec_args->collective_executor = run_opts.collective_executor; - exec_args->step_container = run_opts.step_container; - exec_args->runner = *run_opts.runner; - exec_args->call_frame = frame; - - item->exec->RunAsync( - // Executor args - *exec_args, - // Done callback. - std::bind( - [item, frame, exec_args](DoneCallback done, - // Start unbound arguments. - const Status& status) { - delete exec_args; - done(status); - }, - std::move(done), std::placeholders::_1)); + exec_args.step_id = run_opts.step_id; + exec_args.rendezvous = run_opts.rendezvous; + exec_args.stats_collector = run_opts.stats_collector; + exec_args.cancellation_manager = run_opts.cancellation_manager; + exec_args.collective_executor = run_opts.collective_executor; + exec_args.step_container = run_opts.step_container; + exec_args.runner = *run_opts.runner; + exec_args.call_frame = frame; + + item->exec->RunAsync(exec_args, std::move(done)); } bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { -- GitLab From c200cecbec679cc9dbb219fd06663232f18470ff Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 6 Sep 2018 00:36:01 -0700 Subject: [PATCH 178/540] Parse feature_group_count attributes of CustomCall ops. PiperOrigin-RevId: 211762464 --- tensorflow/compiler/xla/service/hlo_parser.cc | 6 ++++++ tensorflow/compiler/xla/service/hlo_parser_test.cc | 8 ++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 0f26ed4235..7c848ba7b4 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1248,11 +1248,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional custom_call_target; optional window; optional dnums; + optional feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; + attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, + &feature_group_count}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -1264,6 +1267,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (dnums.has_value()) { instruction->set_convolution_dimension_numbers(*dnums); } + if (feature_group_count.has_value()) { + instruction->set_feature_group_count(*feature_group_count); + } break; } case HloOpcode::kDot: { diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 0dfc0a4d1c..43e8736532 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1123,13 +1123,13 @@ ENTRY Iota { )" }, -// custom-call with window and dim_labels +// custom-call with window, dim_labels and feature_group_count { -"CustomCallWithWindowAndDimLabels", -R"(HloModule CustomCallWithWindowAndDimLabels +"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount", +R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount ENTRY Computation { - ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, custom_call_target="target" + ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target" } )" -- GitLab From 3b34d4fa50f421022a8eb83f51660d22862557d2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Sep 2018 02:01:00 -0700 Subject: [PATCH 179/540] compat: Update forward compatibility horizon to 2018-09-06 PiperOrigin-RevId: 211770067 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 586f4c6936..118339bfaf 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -26,7 +26,7 @@ import datetime from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 5) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 6) @tf_export("compat.forward_compatible") -- GitLab From d41f5ffb9cdc1c047db2f7b8a71ef24d39d12fb0 Mon Sep 17 00:00:00 2001 From: Loo Rong Jie Date: Wed, 4 Jul 2018 09:04:57 +0800 Subject: [PATCH 180/540] [Bazel/MSVC] Enable jpeg SIMD for MSVC - Add config/msvc.h when building nasm on Windows - Update Windows SIMD for libjpeg-turbo 2.0.0 - Add missing source files --- third_party/jpeg/jpeg.BUILD | 139 +++++++++++++++++++++++++++++++++++- third_party/nasm.BUILD | 5 +- 2 files changed, 141 insertions(+), 3 deletions(-) diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD index 5edf4f8120..1b9b9bf2f5 100644 --- a/third_party/jpeg/jpeg.BUILD +++ b/third_party/jpeg/jpeg.BUILD @@ -11,8 +11,8 @@ libjpegturbo_nocopts = "-[W]error" WIN_COPTS = [ "/Ox", - "/w14711", # function 'function' selected for inline expansion - "/w14710", # 'function' : function not inlined + "-DWITH_SIMD", + "-wd4996", ] libjpegturbo_copts = select({ @@ -127,6 +127,7 @@ cc_library( ":armeabi-v7a": [":simd_armv7a"], ":arm64-v8a": [":simd_armv8a"], ":linux_ppc64le": [":simd_altivec"], + ":windows": [":simd_win_x86_64"], "//conditions:default": [":simd_none"], }), ) @@ -350,6 +351,140 @@ cc_library( nocopts = libjpegturbo_nocopts, ) +cc_library( + name = "simd_win_x86_64", + srcs = [ + "jchuff.h", + "jconfig.h", + "jconfigint.h", + "jdct.h", + "jerror.h", + "jinclude.h", + "jmorecfg.h", + "jpegint.h", + "jpeglib.h", + "jsimd.h", + "jsimddct.h", + "simd/jsimd.h", + "simd/x86_64/jsimd.c", + "simd/x86_64/jccolor-avx2.obj", + "simd/x86_64/jccolor-sse2.obj", + "simd/x86_64/jcgray-avx2.obj", + "simd/x86_64/jcgray-sse2.obj", + "simd/x86_64/jchuff-sse2.obj", + "simd/x86_64/jcphuff-sse2.obj", + "simd/x86_64/jcsample-avx2.obj", + "simd/x86_64/jcsample-sse2.obj", + "simd/x86_64/jdcolor-avx2.obj", + "simd/x86_64/jdcolor-sse2.obj", + "simd/x86_64/jdmerge-avx2.obj", + "simd/x86_64/jdmerge-sse2.obj", + "simd/x86_64/jdsample-avx2.obj", + "simd/x86_64/jdsample-sse2.obj", + "simd/x86_64/jfdctflt-sse.obj", + "simd/x86_64/jfdctfst-sse2.obj", + "simd/x86_64/jfdctint-avx2.obj", + "simd/x86_64/jfdctint-sse2.obj", + "simd/x86_64/jidctflt-sse2.obj", + "simd/x86_64/jidctfst-sse2.obj", + "simd/x86_64/jidctint-avx2.obj", + "simd/x86_64/jidctint-sse2.obj", + "simd/x86_64/jidctred-sse2.obj", + "simd/x86_64/jquantf-sse2.obj", + "simd/x86_64/jquanti-avx2.obj", + "simd/x86_64/jquanti-sse2.obj", + "simd/x86_64/jsimdcpu.obj", + ], + copts = libjpegturbo_copts, +) + +genrule( + name = "simd_win_x86_64_assemble", + srcs = [ + "jconfig.h", + "jconfigint.h", + "simd/x86_64/jccolext-avx2.asm", + "simd/x86_64/jccolext-sse2.asm", + "simd/x86_64/jccolor-avx2.asm", + "simd/x86_64/jccolor-sse2.asm", + "simd/x86_64/jcgray-avx2.asm", + "simd/x86_64/jcgray-sse2.asm", + "simd/x86_64/jcgryext-avx2.asm", + "simd/x86_64/jcgryext-sse2.asm", + "simd/x86_64/jchuff-sse2.asm", + "simd/x86_64/jcphuff-sse2.asm", + "simd/x86_64/jcsample-avx2.asm", + "simd/x86_64/jcsample-sse2.asm", + "simd/x86_64/jdcolext-avx2.asm", + "simd/x86_64/jdcolext-sse2.asm", + "simd/x86_64/jdcolor-avx2.asm", + "simd/x86_64/jdcolor-sse2.asm", + "simd/x86_64/jdmerge-avx2.asm", + "simd/x86_64/jdmerge-sse2.asm", + "simd/x86_64/jdmrgext-avx2.asm", + "simd/x86_64/jdmrgext-sse2.asm", + "simd/x86_64/jdsample-avx2.asm", + "simd/x86_64/jdsample-sse2.asm", + "simd/x86_64/jfdctflt-sse.asm", + "simd/x86_64/jfdctfst-sse2.asm", + "simd/x86_64/jfdctint-avx2.asm", + "simd/x86_64/jfdctint-sse2.asm", + "simd/x86_64/jidctflt-sse2.asm", + "simd/x86_64/jidctfst-sse2.asm", + "simd/x86_64/jidctint-avx2.asm", + "simd/x86_64/jidctint-sse2.asm", + "simd/x86_64/jidctred-sse2.asm", + "simd/x86_64/jquantf-sse2.asm", + "simd/x86_64/jquanti-avx2.asm", + "simd/x86_64/jquanti-sse2.asm", + "simd/x86_64/jsimdcpu.asm", + "simd/nasm/jcolsamp.inc", + "simd/nasm/jdct.inc", + "simd/nasm/jpeg_nbits_table.inc", + "simd/nasm/jsimdcfg.inc", + "simd/nasm/jsimdcfg.inc.h", + "simd/nasm/jsimdext.inc", + ], + outs = [ + "simd/x86_64/jccolor-avx2.obj", + "simd/x86_64/jccolor-sse2.obj", + "simd/x86_64/jcgray-avx2.obj", + "simd/x86_64/jcgray-sse2.obj", + "simd/x86_64/jchuff-sse2.obj", + "simd/x86_64/jcphuff-sse2.obj", + "simd/x86_64/jcsample-avx2.obj", + "simd/x86_64/jcsample-sse2.obj", + "simd/x86_64/jdcolor-avx2.obj", + "simd/x86_64/jdcolor-sse2.obj", + "simd/x86_64/jdmerge-avx2.obj", + "simd/x86_64/jdmerge-sse2.obj", + "simd/x86_64/jdsample-avx2.obj", + "simd/x86_64/jdsample-sse2.obj", + "simd/x86_64/jfdctflt-sse.obj", + "simd/x86_64/jfdctfst-sse2.obj", + "simd/x86_64/jfdctint-avx2.obj", + "simd/x86_64/jfdctint-sse2.obj", + "simd/x86_64/jidctflt-sse2.obj", + "simd/x86_64/jidctfst-sse2.obj", + "simd/x86_64/jidctint-avx2.obj", + "simd/x86_64/jidctint-sse2.obj", + "simd/x86_64/jidctred-sse2.obj", + "simd/x86_64/jquantf-sse2.obj", + "simd/x86_64/jquanti-avx2.obj", + "simd/x86_64/jquanti-sse2.obj", + "simd/x86_64/jsimdcpu.obj", + ], + cmd = "for out in $(OUTS); do\n" + + " $(location @nasm//:nasm) -fwin64 -DWIN64 -D__x86_64__" + + " -I $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/" + + " -I $$(dirname $(location simd/nasm/jdct.inc))/" + + " -I $$(dirname $(location simd/nasm/jdct.inc))/../../win/" + + " -o $$out" + + " $$(dirname $(location simd/x86_64/jccolext-sse2.asm))/$$(basename $${out%.obj}.asm)\n" + + "done", + tools = ["@nasm"], +) + cc_library( name = "simd_none", srcs = [ diff --git a/third_party/nasm.BUILD b/third_party/nasm.BUILD index 2b877883b9..d746a65e7e 100644 --- a/third_party/nasm.BUILD +++ b/third_party/nasm.BUILD @@ -133,7 +133,10 @@ cc_binary( "x86/regs.c", "x86/regs.h", "x86/regvals.c", - ], + ] + select({ + ":windows": ["config/msvc.h"], + "//conditions:default": [], + }), includes = [ "asm", "include", -- GitLab From f936cfa5498dc386242935a71b154b3c2f78579d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Sep 2018 06:31:09 -0700 Subject: [PATCH 181/540] Documentation fixes for segment_* and unsorted_segment_* ops RELNOTES: n/a PiperOrigin-RevId: 211798876 --- .../api_def/base_api/api_def_SegmentMax.pbtxt | 2 +- .../base_api/api_def_SegmentMean.pbtxt | 2 +- .../api_def/base_api/api_def_SegmentMin.pbtxt | 2 +- .../base_api/api_def_SegmentProd.pbtxt | 2 +- .../api_def/base_api/api_def_SegmentSum.pbtxt | 2 +- .../base_api/api_def_UnsortedSegmentMax.pbtxt | 16 ++++--- .../base_api/api_def_UnsortedSegmentMin.pbtxt | 15 +++--- .../api_def_UnsortedSegmentProd.pbtxt | 15 +++--- .../base_api/api_def_UnsortedSegmentSum.pbtxt | 2 +- tensorflow/python/ops/math_ops.py | 48 +++++++++++++------ 10 files changed, 66 insertions(+), 40 deletions(-) diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt index 35f55fe106..d33a36ce06 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <::lowest()`. +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. +
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt index 37dd973b23..7e139ddf4d 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: <::max()`. + +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt index efbc023705..9c8ea3b620 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: < Date: Thu, 6 Sep 2018 06:31:17 -0700 Subject: [PATCH 182/540] Documentation fix for tf.regex_full_match RELNOTES: n/a PiperOrigin-RevId: 211798892 --- tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt index 8cef243aee..30fd97a0d7 100644 --- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt @@ -9,7 +9,7 @@ END in_arg { name: "pattern" description: < Date: Thu, 6 Sep 2018 07:29:33 -0700 Subject: [PATCH 183/540] Documentation fix for TensorShape.__getitem__ RELNOTES: n/a PiperOrigin-RevId: 211804843 --- tensorflow/python/framework/tensor_shape.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py index 11b681d544..3c2a736fb9 100644 --- a/tensorflow/python/framework/tensor_shape.py +++ b/tensorflow/python/framework/tensor_shape.py @@ -606,8 +606,8 @@ class TensorShape(object): slice. Raises: - ValueError: If `key` is a slice, and any of its elements are negative, or - if `self` is completely unknown and the step is set. + ValueError: If `key` is a slice and `self` is completely unknown and + the step is set. """ if self._dims is not None: if isinstance(key, slice): -- GitLab From 8859ee06cc0cba03d05ce9677b05ff1993c34b03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yan=20Facai=20=28=E9=A2=9C=E5=8F=91=E6=89=8D=29?= Date: Thu, 6 Sep 2018 22:45:25 +0800 Subject: [PATCH 184/540] TST: add more test cases --- .../kernel_tests/broadcast_to_ops_test.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py index 282a619094..8bcf27466c 100644 --- a/tensorflow/python/kernel_tests/broadcast_to_ops_test.py +++ b/tensorflow/python/kernel_tests/broadcast_to_ops_test.py @@ -82,8 +82,8 @@ class BroadcastToTest(test_util.TensorFlowTestCase): # check shape inference when shape input is constant self.assertAllEqual(shape, v_np.shape) - def testGradient(self): - x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + def testGradientForScalar(self): + x = constant_op.constant(1, dtype=dtypes.float32) v = array_ops.broadcast_to(x, [2, 4, 3]) out = 2 * v with self.test_session(): @@ -91,9 +91,29 @@ class BroadcastToTest(test_util.TensorFlowTestCase): out, out.get_shape()) self.assertLess(err, 1e-4) - def testGradientForScalar(self): - x = constant_op.constant(1, dtype=dtypes.float32) - v = array_ops.broadcast_to(x, [2, 4, 3]) + def testGradientWithSameRank(self): + x = constant_op.constant(np.reshape(np.arange(6), (2, 1, 3)), + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [2, 5, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithIncreasingRank(self): + x = constant_op.constant([[1], [2]], + dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 2, 3]) + out = 2 * v + with self.test_session(): + err = gradient_checker.compute_gradient_error(x, x.get_shape(), + out, out.get_shape()) + self.assertLess(err, 1e-4) + + def testGradientWithBroadcastAllDimensions(self): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32) + v = array_ops.broadcast_to(x, [5, 4, 6]) out = 2 * v with self.test_session(): err = gradient_checker.compute_gradient_error(x, x.get_shape(), -- GitLab From 35f28c57da8aad4a79503db955b11fed63b1fe34 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 6 Sep 2018 08:45:54 -0700 Subject: [PATCH 185/540] Add a command line option to serialize api-reference resolver. PiperOrigin-RevId: 211813852 --- tensorflow/tools/docs/generate_lib.py | 10 ++++++++++ tensorflow/tools/docs/parser.py | 7 ++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 483921fc2f..7db89f7d24 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -548,6 +548,13 @@ class DocGenerator(object): help='The path from the site-root to api_docs' 'directory for this project') + self.argument_parser.add_argument( + '--api_cache_out_path', + type=str, + default=None, + help='Path to store a json-serialized api-index, so links can be ' + 'inserted into docs without rebuilding the api_docs') + def add_output_dir_argument(self): self.argument_parser.add_argument( '--output_dir', @@ -648,6 +655,9 @@ class DocGenerator(object): visitor = self.run_extraction() reference_resolver = self.make_reference_resolver(visitor, doc_index) + if getattr(flags, 'api_cache_out_path', None): + reference_resolver.to_json_file(flags.api_cache_out_path) + # Build the guide_index for the api_docs back links. root_title = getattr(flags, 'root_title', 'TensorFlow') guide_index = _build_guide_index( diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 549056c6c4..4afb61e365 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -153,6 +153,7 @@ class ReferenceResolver(object): self._doc_index = doc_index self._is_class = is_class self._is_module = is_module + self._all_names = set(is_class.keys()) self._py_module_names = py_module_names @@ -210,6 +211,10 @@ class ReferenceResolver(object): Args: filepath: The file path to write the json to. """ + try: + os.makedirs(os.path.dirname(filepath)) + except OSError: + pass json_dict = {} for key, value in self.__dict__.items(): # Drop these two fields. `_doc_index` is not serializable. `_all_names` is @@ -223,7 +228,7 @@ class ReferenceResolver(object): json_dict[key.lstrip('_')] = value with open(filepath, 'w') as f: - json.dump(json_dict, f) + json.dump(json_dict, f, indent=2, sort_keys=True) def replace_references(self, string, relative_path_to_root): """Replace "@{symbol}" references with links to symbol's documentation page. -- GitLab From a41e270641f0613413e1929c9010f32882b4d26b Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Thu, 6 Sep 2018 08:56:46 -0700 Subject: [PATCH 186/540] Add HloSchedule to HloModule. Add HloSchedule as a field on HloModule. This will enable scheduling to be a normal HLO pass and enable some passes such as copy insertion to more easily use tighter instruction live ranges based on the schedule. This change required adding HloSchedule to the "hlo" library because of circular dependencies. Nothing except for tests actually sets the schedule at the moment, but follow up cls will add a scheduling pass which will do so. PiperOrigin-RevId: 211815293 --- tensorflow/compiler/xla/service/BUILD | 30 ++--- tensorflow/compiler/xla/service/gpu/BUILD | 1 - tensorflow/compiler/xla/service/hlo.proto | 26 +++-- .../compiler/xla/service/hlo_computation.cc | 12 +- .../compiler/xla/service/hlo_computation.h | 5 + tensorflow/compiler/xla/service/hlo_module.cc | 33 +++++- tensorflow/compiler/xla/service/hlo_module.h | 20 ++++ .../compiler/xla/service/hlo_module_test.cc | 59 ++++++++++ .../compiler/xla/service/hlo_ordering.cc | 17 --- .../compiler/xla/service/hlo_ordering.h | 4 - tensorflow/compiler/xla/service/hlo_parser.cc | 33 +++++- .../compiler/xla/service/hlo_parser_test.cc | 104 +++++++++++++++++- .../compiler/xla/service/hlo_proto_util.cc | 3 - .../compiler/xla/service/hlo_schedule.cc | 52 +++++++++ .../compiler/xla/service/hlo_schedule.h | 13 ++- 15 files changed, 346 insertions(+), 66 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index ab86dce510..b8ee6a093e 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -291,6 +291,7 @@ cc_library( "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_schedule.cc", "hlo_sharding.cc", ], hdrs = [ @@ -303,6 +304,7 @@ cc_library( "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", + "hlo_schedule.h", "hlo_sharding.h", ], deps = [ @@ -331,6 +333,8 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) @@ -1037,7 +1041,6 @@ tf_cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1065,7 +1068,6 @@ cc_library( ":hlo", ":hlo_dataflow_analysis", ":hlo_proto", - ":hlo_schedule", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1086,7 +1088,6 @@ tf_cc_test( ":hlo", ":hlo_dataflow_analysis", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1108,7 +1109,6 @@ cc_library( ":hlo", ":hlo_ordering", ":hlo_proto", - ":hlo_schedule", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1177,22 +1177,6 @@ cc_library( ], ) -cc_library( - name = "hlo_schedule", - srcs = ["hlo_schedule.cc"], - hdrs = ["hlo_schedule.h"], - deps = [ - ":hlo", - "//tensorflow/compiler/xla:status", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib_internal", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - ], -) - tf_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], @@ -1202,7 +1186,6 @@ tf_cc_test( ":hlo_dce", ":hlo_ordering", ":hlo_parser", - ":hlo_schedule", ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", @@ -1222,7 +1205,6 @@ cc_library( ":heap_simulator", ":hlo", ":hlo_ordering", - ":hlo_schedule", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -1969,6 +1951,8 @@ tf_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":hlo", + ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1977,6 +1961,7 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", + "//tensorflow/core:test", "@com_google_absl//absl/memory", "@com_google_absl//absl/types:span", ], @@ -2413,7 +2398,6 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", - ":hlo_schedule", ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 13ccff35f8..a68b7a1bef 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -813,7 +813,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_reachability", - "//tensorflow/compiler/xla/service:hlo_schedule", "//tensorflow/compiler/xla/service:hlo_scheduling", "@com_google_absl//absl/memory", ], diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 99d0cf50ca..93ec2c9438 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -199,6 +199,17 @@ message HloComputationProto { int64 root_id = 6; } +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map sequences = 1; +} + // Serialization of HloModule. message HloModuleProto { string name = 1; @@ -214,16 +225,9 @@ message HloModuleProto { // The id of this module. int64 id = 5; -} -// Serialization of HloOrdering. -message HloOrderingProto { - // NOTE: currently only sequential orderings are serialized. - message SequentialComputation { - string computation_name = 1; - repeated string instruction_names = 2; - } - repeated SequentialComputation sequential_computations = 1; + // The schedule for this module. + HloScheduleProto schedule = 7; } // Serialization of LogicalBuffer. @@ -322,8 +326,10 @@ message BufferAssignmentProto { // Grouping message that contains all of the information above. message HloProto { + reserved 2; + reserved "hlo_ordering"; + HloModuleProto hlo_module = 1; - HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index fe7f2be888..233d2199d1 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -464,6 +464,14 @@ std::vector HloComputation::MakeEmbeddedComputationsList() } string HloComputation::ToString(const HloPrintOptions& options) const { + return ToString(options, MakeInstructionPostOrder()); +} + +string HloComputation::ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const { + CHECK_EQ(instruction_order.size(), instruction_count()); + std::ostringstream s; for (int i = 0; i < options.indent_amount(); i++) { s << " "; @@ -486,7 +494,9 @@ string HloComputation::ToString(const HloPrintOptions& options) const { new_options.set_indent_amount(options.indent_amount() + 1) .set_is_in_nested_computation(true); CanonicalNameMap name_map; - for (const HloInstruction* instruction : MakeInstructionPostOrder()) { + for (const HloInstruction* instruction : instruction_order) { + CHECK_EQ(this, instruction->parent()); + for (int i = 0; i < new_options.indent_amount(); i++) { s << " "; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fe2d3bbbe5..91c5234a6f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -170,6 +170,11 @@ class HloComputation { string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; + // Overload which accepts an order to emit the instructions in. + string ToString( + const HloPrintOptions& options, + absl::Span instruction_order) const; + // Returns a serialized representation of this computation. HloComputationProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 3a1bc4e328..cfe906d9c5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -50,6 +51,13 @@ StatusOr HloModule::LaunderConstInstructionFromModule( return const_cast(hlo); } +Status HloModule::set_schedule(HloSchedule schedule) { + TF_RET_CHECK(schedule.module() == this); + TF_RETURN_IF_ERROR(schedule.Verify()); + schedule_ = std::move(schedule); + return Status::OK(); +} + HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, bool uniquify_names) { @@ -198,12 +206,23 @@ void HloModule::ReplaceComputations( string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << "\n\n"; + s << "HloModule " << name(); + if (has_schedule()) { + TF_CHECK_OK(schedule().Verify()); + s << ", is_scheduled=true"; + } + s << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString(options) << "\n\n"; + if (has_schedule() && schedule().is_computation_scheduled(computation)) { + s << computation->ToString( + options, schedule().sequence(computation).instructions()) + << "\n\n"; + } else { + s << computation->ToString(options) << "\n\n"; + } } return s.str(); } @@ -221,6 +240,9 @@ HloModuleProto HloModule::ToProto() const { } proto.add_computations()->Swap(&computation_proto); } + if (has_schedule()) { + *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); + } return proto; } @@ -309,6 +331,13 @@ StatusOr> HloModule::CreateFromProto( } } + if (proto.has_schedule()) { + TF_ASSIGN_OR_RETURN( + HloSchedule schedule, + HloSchedule::CreateFromProto(module.get(), proto.schedule())); + TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule))); + } + return std::move(module); } diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 3c3371426b..26fd1b2438 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" @@ -32,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -235,6 +237,19 @@ class HloModule { StatusOr LaunderConstInstructionFromModule( const HloInstruction* hlo); + // Sets the schedule of the module to the given schedule. + Status set_schedule(HloSchedule schedule); + + // Clears the schedule of the module. + void clear_schedule() { schedule_.reset(); } + + // Returns true if the module has a schedule set. + bool has_schedule() const { return schedule_.has_value(); } + + // Returns the schedue of the module. CHECK fails if no schedule is set. + const HloSchedule& schedule() const { return *schedule_; } + HloSchedule& schedule() { return *schedule_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -262,6 +277,11 @@ class HloModule { static std::atomic next_unique_module_id_; // A unique id to label modules with. int unique_id_; + + // The HloSchedule of the module. The schedule if it exists contains a + // sequential order of instructions for each non-fusion computation in the + // module. + absl::optional schedule_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 4bc1bacd7d..400bd4d947 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -19,9 +19,12 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/test.h" @@ -30,6 +33,8 @@ namespace xla { namespace { +namespace op = ::xla::testing::opcode_matchers; + class HloModuleTest : public HloTestBase { protected: HloModuleTest() {} @@ -194,6 +199,60 @@ TEST_F(HloModuleTest, UniqueModuleId) { EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } +TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_FALSE(module_copy->has_schedule()); +} + +TEST_F(HloModuleTest, ProtoSerializationWithSchedule) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module_copy, + HloModule::CreateFromProto(module->ToProto(), module->config())); + ASSERT_TRUE(module_copy->has_schedule()); + TF_ASSERT_OK(module_copy->schedule().Verify()); + EXPECT_EQ(module_copy->schedule().sequences().size(), 1); + ASSERT_TRUE(module_copy->schedule().is_computation_scheduled( + module_copy->entry_computation())); + EXPECT_THAT( + module_copy->schedule() + .sequence(module_copy->entry_computation()) + .instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 2105f7a349..f1dc08bafa 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -293,23 +293,6 @@ bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, !LiveRangeStrictlyBefore(b, a, dataflow); } -HloOrderingProto HloOrdering::ToProto() const { - HloOrderingProto proto; - for (const auto& computation : module_->computations()) { - const std::vector* sequence = - SequentialOrder(*computation); - if (sequence != nullptr) { - HloOrderingProto::SequentialComputation* proto_computation = - proto.add_sequential_computations(); - proto_computation->set_computation_name(computation->name()); - for (const HloInstruction* instruction : *sequence) { - *proto_computation->add_instruction_names() = instruction->name(); - } - } - } - return proto; -} - PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) : HloOrdering(module) {} diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b21071c4b2..b0361c3f02 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -72,10 +72,6 @@ class HloOrdering { virtual string ToString() const = 0; - // Returns the serialized representation of this ordering. - // Only sequential computation orders are represented. - HloOrderingProto ToProto() const; - protected: // Returns true if instruction 'a' executes before instruction 'b'. // Precondition: 'a' and 'b' are in the same computation. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 7c848ba7b4..c54360b063 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" @@ -44,6 +45,20 @@ using absl::StrJoin; const double kF16max = 65504; +// Creates and returns a schedule created using the order of the instructions in +// the HloComputation::instructions() vectors in the module. +HloSchedule ScheduleFromInstructionOrder(const HloModule* module) { + HloSchedule schedule(module); + for (const HloComputation* computation : module->computations()) { + if (!computation->IsFusionComputation()) { + for (const HloInstruction* instruction : computation->instructions()) { + schedule.GetOrCreateSequence(computation).push_back(instruction); + } + } + } + return schedule; +} + // Parser for the HloModule::ToString() format text. class HloParser { public: @@ -366,9 +381,25 @@ bool HloParser::ParseHloModule() { return false; } + absl::optional is_scheduled; + std::unordered_map attrs; + attrs["is_scheduled"] = {/*required=*/false, AttrTy::kBool, &is_scheduled}; + if (!ParseAttributes(attrs)) { + return false; + } + module_ = absl::make_unique(name, config_); - return ParseComputations(); + if (!ParseComputations()) { + return false; + } + + if (is_scheduled.has_value() && *is_scheduled) { + TF_CHECK_OK( + module_->set_schedule(ScheduleFromInstructionOrder(module_.get()))); + } + + return true; } // computations ::= (computation)+ diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 43e8736532..cca50fab54 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1133,8 +1133,21 @@ ENTRY Computation { } )" + }, +// is_scheduled=true attribute +{ +"ScheduledModule", +R"(HloModule scheduled_module, is_scheduled=true + +ENTRY Sort { + keys = f32[1024]{0} parameter(0) + values = s32[1024]{0} parameter(1) + ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0} } - }); + +)" +} +}); // clang-format on } @@ -1790,5 +1803,94 @@ TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { EXPECT_EQ(convolution->feature_group_count(), 1); } +TEST_F(HloParserTest, IsScheduledIsFalse) { + const string text = R"( +HloModule axpy_module, is_scheduled=false + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledNotPresent) { + const string text = R"( +HloModule axpy_module + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_FALSE(module->has_schedule()); +} + +TEST_F(HloParserTest, IsScheduledIsTrue) { + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Broadcast(), op::Parameter(), + op::Multiply(), op::Parameter(), op::Add())); +} + +TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) { + // As above but in with a different schedule order. + const string text = R"( +HloModule axpy_module, is_scheduled=true + +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %y = f32[2,4]{1,0} parameter(2) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(text)); + ASSERT_TRUE(module->has_schedule()); + TF_ASSERT_OK(module->schedule().Verify()); + EXPECT_EQ(module->schedule().sequences().size(), 1); + ASSERT_TRUE( + module->schedule().is_computation_scheduled(module->entry_computation())); + EXPECT_THAT( + module->schedule().sequence(module->entry_computation()).instructions(), + ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(), + op::Broadcast(), op::Multiply(), op::Add())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 3460679558..b9c0b0c4ee 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -23,11 +23,8 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloOrderingProto proto_ordering = - assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); HloProto proto = MakeHloProto(module); - proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc index a65b33bf40..3fc5dbeb02 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/hlo_schedule.cc @@ -21,12 +21,64 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" namespace xla { +/* static */ StatusOr HloSchedule::CreateFromProto( + const HloModule* module, const HloScheduleProto& proto) { + tensorflow::gtl::FlatMap id_to_computation; + for (const HloComputation* computation : module->computations()) { + id_to_computation[computation->unique_id()] = computation; + } + + HloSchedule schedule(module); + for (const auto& id_sequence : proto.sequences()) { + int64 computation_id = id_sequence.first; + + auto comp_it = id_to_computation.find(computation_id); + TF_RET_CHECK(comp_it != id_to_computation.end()) + << "No computation exists in HLO module with id " << computation_id; + const HloComputation* computation = comp_it->second; + + tensorflow::gtl::FlatMap id_to_instruction; + for (const HloInstruction* instruction : computation->instructions()) { + id_to_instruction[instruction->unique_id()] = instruction; + } + + HloInstructionSequence& sequence = + schedule.GetOrCreateSequence(computation); + for (const int64 instruction_id : id_sequence.second.instruction_ids()) { + auto instr_it = id_to_instruction.find(instruction_id); + TF_RET_CHECK(instr_it != id_to_instruction.end()) + << "No instruction exists in HLO computation " << computation->name() + << " with id " << instruction_id; + sequence.push_back(instr_it->second); + } + } + TF_RETURN_IF_ERROR(schedule.Verify()); + return std::move(schedule); +} + +StatusOr HloSchedule::ToProto() const { + TF_RETURN_IF_ERROR(Verify()); + HloScheduleProto proto; + for (const auto& id_sequence : sequences_) { + int64 computation_id = id_sequence.first; + const HloInstructionSequence& sequence = id_sequence.second; + HloScheduleProto::InstructionSequence& proto_sequence = + (*proto.mutable_sequences())[computation_id]; + proto_sequence.mutable_instruction_ids()->Reserve(sequence.size()); + for (const int64 id : sequence.ids()) { + proto_sequence.add_instruction_ids(id); + } + } + return std::move(proto); +} + void HloSchedule::set_sequence( const HloComputation* computation, absl::Span sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h index 21c6988638..270fe6039f 100644 --- a/tensorflow/compiler/xla/service/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/hlo_schedule.h @@ -21,18 +21,20 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" #include "tensorflow/compiler/xla/status.h" namespace xla { +class HloModule; + // Class representing a sequence of HLO instructions such as the sequential // execution order of an HLO computation. class HloInstructionSequence { public: HloInstructionSequence() = default; - HloInstructionSequence(absl::Span instructions) { + explicit HloInstructionSequence( + absl::Span instructions) { for (const HloInstruction* instruction : instructions) { push_back(instruction); } @@ -77,7 +79,12 @@ class HloInstructionSequence { // non-fusion computation in the HLO module. class HloSchedule { public: - HloSchedule(const HloModule* module) : module_(module) {} + explicit HloSchedule(const HloModule* module) : module_(module) {} + + // (De)Serialize an HloSchedule to/from a HloScheduleProto. + static StatusOr CreateFromProto(const HloModule* module, + const HloScheduleProto& proto); + StatusOr ToProto() const; // Returns a reference to the sequence for the given computation. const HloInstructionSequence& sequence( -- GitLab From bfff3425e0938c6bcc635edce2673252c4762a99 Mon Sep 17 00:00:00 2001 From: Doe Hyun Yoon Date: Thu, 6 Sep 2018 09:42:22 -0700 Subject: [PATCH 187/540] Replace Placeholder with Const to GrapplerFunctionItem for function shape inference if possible. PiperOrigin-RevId: 211821596 --- .../core/grappler/costs/graph_properties.cc | 50 +++++++++++++---- .../grappler/costs/graph_properties_test.cc | 55 ++++++++++++++++++- 2 files changed, 91 insertions(+), 14 deletions(-) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 6710ff9df3..d24e7e8ee4 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -429,18 +429,22 @@ class SymbolicShapeRefiner { // perform shape inference on the function body. // // Propagate shape information of final function body node - // to function node `node`. + // to function node `function_node`. // - // In the event of an error, UpdateNode will simply set `node`'s + // In the event of an error, UpdateNode will simply set `function_node`'s // output shape to be Unknown. - Status UpdateFunction(const NodeDef* node) { - auto it = fun_to_grappler_function_item_.find(node->op()); + Status UpdateFunction(const NodeDef* function_node) { + auto it = fun_to_grappler_function_item_.find(function_node->op()); if (it == fun_to_grappler_function_item_.end()) { return errors::InvalidArgument( - node->op(), " was not previously added to SymbolicShapeRefiner."); + function_node->op(), + " was not previously added to SymbolicShapeRefiner."); } - GrapplerFunctionItem& grappler_function_item = it->second; + // Copy (not reference) so that changes we make here (e.g., replacing + // Placeholder with Const) don't affect one in + // fun_to_grappler_function_item_. + GrapplerFunctionItem grappler_function_item = it->second; GraphView gv(&grappler_function_item.graph); // Forward shapes from function input nodes to argument nodes. @@ -453,7 +457,7 @@ class SymbolicShapeRefiner { "supported."); } NodeDef* fun_node = gv.GetNode(fun_input.input_name); - const string& input = node->input(i); + const string& input = function_node->input(i); const string& node_name = NodeName(input); if (IsControlInput(input)) { @@ -478,16 +482,35 @@ class SymbolicShapeRefiner { TensorShapeProto proto; const auto& handle = input_inference_context->output(output_port_num); input_inference_context->ShapeHandleToProto(handle, &proto); + // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1. + for (int i = 0; i < proto.dim_size(); i++) { + if (proto.dim(i).size() < -1) { + proto.mutable_dim(i)->set_size(-1); + } + } *attr_output_shape.mutable_shape() = proto; (*fun_node->mutable_attr())["shape"] = attr_output_shape; } + // Replace input Placeholders with Consts, if values are known. Note that + // we don't check exceptions here as it's done in the above loop. + for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) { + const string& input = function_node->input(i); + const string& node_name = NodeName(input); + NodeDef* input_node = graph_.GetNode(node_name); + // TODO(dyoon): also use Const when output_tensors_as_shape is available. + if (IsConstant(*input_node)) { + TF_CHECK_OK( + ReplaceInputWithConst(*input_node, i, &grappler_function_item)); + } + } + // Perform inference on function body. GraphProperties gp(grappler_function_item); TF_RETURN_IF_ERROR(gp.InferStatically(true)); // Add return nodes for output shapes. - auto ic = GetContext(node); + auto ic = GetContext(function_node); int output = 0; for (auto const& out_arg : grappler_function_item.outputs()) { if (out_arg.output_tensors.size() > 1) { @@ -505,8 +528,9 @@ class SymbolicShapeRefiner { const NodeDef* retnode = gv.GetNode(node_name); if (retnode == nullptr) { - return errors::FailedPrecondition("Unable to find return node ", - node_name, " for ", node->name()); + return errors::FailedPrecondition( + "Unable to find return function_node ", node_name, " for ", + function_node->name()); } auto output_properties = gp.GetOutputProperties(retnode->name()); @@ -671,11 +695,13 @@ class SymbolicShapeRefiner { // true, as the updates to the call node will have changed, even if it's // the same function being called twice with the same input shapes. // Example: simple_function.pbtxt - if (UpdateFunction(node).ok()) { + auto s = UpdateFunction(node); + if (s.ok()) { return Status::OK(); } else { VLOG(1) << "UpdateFunction failed for " << node->op() - << ". Defaulting to ShapeUnknown."; + << ". Defaulting to ShapeUnknown.\n" + << s.ToString(); } } diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 8938b7c32e..3ec68a4e59 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -785,7 +785,58 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { EXPECT_EQ("float: [128,256]", PropToString(prop)); } -TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { +TEST_F(GraphPropertiesTest, FunctionWithConstInput) { + FunctionDefLibrary library; + // This function is simply + // out = Fill(shape, value), but + // Fill requires values in the shape input, not just shape of it, to infer + // output shape; hence, func + *library.add_function() = FunctionDefHelper::Create( + // Name + "MyFillFunc", + // Inputs + {"shape: int32", "value: float"}, + // Outputs + {"out: float"}, + // Attrs + {}, + // Nodes + { + {{"a"}, + "Fill", + {"shape", "value"}, + {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}}, + }, + // Returns + {{"out", "a:output:0"}}); + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(library)); + Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4}); + Output value = ops::Const(s.WithOpName("value"), 0.1f, {}); + auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc", + s.graph()->op_registry()); + tensorflow::Node* func_op; + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto _value = tensorflow::ops::AsNodeOut(s, value); + TF_CHECK_OK( + builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op)); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("MyFillFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); + EXPECT_FALSE(out_prop0.shape().unknown_rank()); + EXPECT_EQ(4, out_prop0.shape().dim_size()); + EXPECT_EQ(1, out_prop0.shape().dim(0).size()); + EXPECT_EQ(2, out_prop0.shape().dim(1).size()); + EXPECT_EQ(3, out_prop0.shape().dim(2).size()); + EXPECT_EQ(4, out_prop0.shape().dim(3).size()); +} + +TEST_F(GraphPropertiesTest, FunctionWithScalarInput) { // Create graph with a function that takes a scalar value so that we use // Placeholder with scalar as for input to the function shape inference. // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of @@ -818,7 +869,7 @@ TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { // MyFunc output shouldn't be unknown rank. GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically(false)); + TF_CHECK_OK(properties.InferStatically(true)); const auto out_props = properties.GetOutputProperties("MyFunc"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); -- GitLab From d17016a8dfd9b9bd92a55fc1fddee4fd1c29bdbe Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Thu, 6 Sep 2018 10:01:46 -0700 Subject: [PATCH 188/540] Extend ConditionalAccumulator with SUM functionality. Previously take_grad represents the average gradients being aggregated. However this does not cover other use cases such as summing quantiles, or summing probability distributions from parallel workers. This change extends the functionality. PiperOrigin-RevId: 211824519 --- .../core/kernels/conditional_accumulator.h | 6 +- .../kernels/conditional_accumulator_base.cc | 13 ++- .../kernels/conditional_accumulator_base.h | 3 +- .../kernels/conditional_accumulator_base_op.h | 3 + .../kernels/conditional_accumulator_op.cc | 3 +- .../kernels/sparse_conditional_accumulator.h | 4 +- .../sparse_conditional_accumulator_op.cc | 4 +- .../typed_conditional_accumulator_base.h | 5 +- tensorflow/core/ops/data_flow_ops.cc | 2 + .../conditional_accumulator_test.py | 88 +++++++++++++++++-- .../sparse_conditional_accumulator_test.py | 83 +++++++++++++++-- tensorflow/python/ops/data_flow_ops.py | 20 ++++- .../tensorflow.-conditional-accumulator.pbtxt | 2 +- ...flow.-sparse-conditional-accumulator.pbtxt | 2 +- .../tensorflow.-conditional-accumulator.pbtxt | 2 +- ...flow.-sparse-conditional-accumulator.pbtxt | 2 +- 16 files changed, 207 insertions(+), 35 deletions(-) diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h index a7836896c7..390db8fe5a 100644 --- a/tensorflow/core/kernels/conditional_accumulator.h +++ b/tensorflow/core/kernels/conditional_accumulator.h @@ -51,9 +51,11 @@ class ConditionalAccumulator // dtype: The datatype of the gradients to be accumulated. // shape: The shape of the accumulated gradients. // name: A name to use for the ConditionalAccumulator. + // reduction_type: The reduction type, i.e., MEAN or SUM ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : TypedConditionalAccumulatorBase(dtype, shape, name) {} + const string& name, const string& reduction_type) + : TypedConditionalAccumulatorBase(dtype, shape, name, + reduction_type) {} ~ConditionalAccumulator() override{}; protected: diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc index 90593c56b8..292cf0cd64 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.cc +++ b/tensorflow/core/kernels/conditional_accumulator_base.cc @@ -14,12 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/conditional_accumulator_base.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { ConditionalAccumulatorBase::ConditionalAccumulatorBase( - const DataType& dtype, const PartialTensorShape& shape, const string& name) - : dtype_(dtype), shape_(shape), name_(name) { + const DataType& dtype, const PartialTensorShape& shape, const string& name, + const string& reduction_type) + : dtype_(dtype), + shape_(shape), + name_(name), + reduction_type_(reduction_type) { counter_ = 0; current_global_step_ = 0; } @@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx, current_global_step_++; // Average the accumulated gradient - DivideAccumGradByCounter(ctx); + if (reduction_type_ == "MEAN") { + DivideAccumGradByCounter(ctx); + } // Set output for accumulated gradient tensor bool successful_set_output = SetOutput(ctx); diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h index b7b7482a00..4a5ec6f0fb 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.h +++ b/tensorflow/core/kernels/conditional_accumulator_base.h @@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase { // name: A name to use for the ConditionalAccumulator. ConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name); + const string& name, const string& reduction_type); typedef AsyncOpKernel::DoneCallback DoneCallback; @@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase { const DataType dtype_; const PartialTensorShape shape_; const string name_; + const string reduction_type_; mutex mu_; int counter_ GUARDED_BY(mu_); int64 current_global_step_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h index 012a0dcc12..ca24d690f8 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.h +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h @@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel { &accumulator_handle_, nullptr)); OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(context, + context->GetAttr("reduction_type", &reduction_type_)); } void Compute(OpKernelContext* ctx) override { @@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel { DataType dtype_; PartialTensorShape shape_; ContainerInfo cinfo_; + string reduction_type_; private: Status SetAccumulatorHandle(OpKernelContext* ctx) diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index e13bf8a4c6..52ac51a9b6 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { ConditionalAccumulator* accumulator = - new ConditionalAccumulator(dtype_, shape_, cinfo_.name()); + new ConditionalAccumulator(dtype_, shape_, cinfo_.name(), + reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h index 11149c4d16..a4453bd7ab 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator.h +++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h @@ -50,10 +50,10 @@ class SparseConditionalAccumulator public: SparseConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) + const string& name, const string& reduction_type) : TypedConditionalAccumulatorBase< std::tuple>( - dtype, shape, name) { + dtype, shape, name, reduction_type) { accum_idx_vec_ = nullptr; count_element_ = nullptr; accum_val_ = nullptr; diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc index 80bc1f1934..1e542a26a7 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc @@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { SparseConditionalAccumulator* accumulator = - new SparseConditionalAccumulator(dtype_, shape_, - cinfo_.name()); + new SparseConditionalAccumulator( + dtype_, shape_, cinfo_.name(), reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h index 9dedb618f9..ca341e511e 100644 --- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h +++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h @@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase { public: TypedConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : ConditionalAccumulatorBase(dtype, shape, name) {} + const string& name, + const string& reduction_type) + : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {} /** * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index eed0bce174..ffab8ad661 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); @@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); diff --git a/tensorflow/python/kernel_tests/conditional_accumulator_test.py b/tensorflow/python/kernel_tests/conditional_accumulator_test.py index 7570523495..86802664d1 100644 --- a/tensorflow/python/kernel_tests/conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/conditional_accumulator_test.py @@ -42,14 +42,22 @@ class ConditionalAccumulatorTest(test.TestCase): with ops.Graph().as_default(): q = data_flow_ops.ConditionalAccumulator(dtypes_lib.float32, name="Q") self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'ConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { unknown_rank: true} } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) + def testConstructorWithInvalidArg(self): + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, name="Q", reduction_type="Invalid") + def testConstructorWithShape(self): with ops.Graph().as_default(): q = data_flow_ops.ConditionalAccumulator( @@ -57,7 +65,8 @@ class ConditionalAccumulatorTest(test.TestCase): name="Q", shape=tensor_shape.TensorShape([1, 5, 2, 8])) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'ConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { dim {size: 1 } @@ -67,6 +76,7 @@ class ConditionalAccumulatorTest(test.TestCase): } } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) def testAccumulatorSizeEmpty(self): @@ -237,12 +247,11 @@ class ConditionalAccumulatorTest(test.TestCase): extract_t.op.run() self.assertEqual(q.num_accumulated().eval(), 0) - def testAccumulatorTakeGrad(self): + def testAccumulatorTakeGradMean(self): with self.test_session(): q = data_flow_ops.ConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) elems = [10.0, 20.0] - elems_ave = sum(elems) / len(elems) accum_ops = [q.apply_grad((x,), local_step=0) for x in elems] takeg_t = q.take_grad(1) @@ -251,7 +260,7 @@ class ConditionalAccumulatorTest(test.TestCase): accum_op.run() val = takeg_t.eval() - self.assertEqual(elems_ave, val) + self.assertEqual(15.0, val) accum_ops = [q.apply_grad((x,), local_step=1) for x in elems] takeg_t = q.take_grad(constant_op.constant(1)) @@ -260,7 +269,42 @@ class ConditionalAccumulatorTest(test.TestCase): accum_op.run() val = takeg_t.eval() - self.assertEqual(elems_ave, val) + self.assertEqual(15.0, val) + + def testAccumulatorTakeGradSum(self): + with self.test_session(): + q = data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([1]), + reduction_type="SUM") + elems = [10.0, 20.0] + + accum_ops = [q.apply_grad((x,), local_step=0) for x in elems] + takeg_t = q.take_grad(1) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(30.0, val) + + accum_ops = [q.apply_grad((x,), local_step=1) for x in elems] + takeg_t = q.take_grad(constant_op.constant(1)) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(30.0, val) + + def testAccumulatorTakeGradInvalidReductionType(self): + with self.assertRaises(ValueError): + data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([1]), + reduction_type="Invalid") def testAccumulatorInvalidTakeGrad(self): with self.test_session(): @@ -277,7 +321,7 @@ class ConditionalAccumulatorTest(test.TestCase): with self.assertRaises(errors_impl.InvalidArgumentError): takeg_t.eval() - def testAccumulatorRepeatedTakeGrad(self): + def testAccumulatorRepeatedTakeGradMean(self): with self.test_session(): q = data_flow_ops.ConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1])) @@ -304,6 +348,36 @@ class ConditionalAccumulatorTest(test.TestCase): val = takeg_t.eval() self.assertEqual(elems_ave + 0.0, val) + def testAccumulatorRepeatedTakeGradSum(self): + with self.test_session(): + q = data_flow_ops.ConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([1]), + reduction_type="SUM") + + elems = [10.0, 20.0] + elems_sum = 30.0 + accum_ops = [q.apply_grad((x,), local_step=0) for x in elems] + takeg_t = q.take_grad(1) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(elems_sum, val) + + elems = [20.0, 30.0] + elems_sum = 50.0 + accum_ops = [q.apply_grad((x,), local_step=1) for x in elems] + takeg_t = q.take_grad(1) + + for accum_op in accum_ops: + accum_op.run() + + val = takeg_t.eval() + self.assertEqual(elems_sum, val) + def testAccumulatorIncrementGlobalStep(self): with self.test_session(): q = data_flow_ops.ConditionalAccumulator( diff --git a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py index d749843410..3bb5e899fe 100644 --- a/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py +++ b/tensorflow/python/kernel_tests/sparse_conditional_accumulator_test.py @@ -61,14 +61,22 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q") self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'SparseConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { unknown_rank: true} } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) + def testConstructorWithInvalidArg(self): + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, name="Q", reduction_type="Invalid") + def testConstructorWithShape(self): with ops.Graph().as_default(): q = data_flow_ops.SparseConditionalAccumulator( @@ -76,7 +84,8 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): name="Q", shape=tensor_shape.TensorShape([1, 5, 2, 8])) self.assertTrue(isinstance(q.accumulator_ref, ops.Tensor)) - self.assertProtoEquals(""" + self.assertProtoEquals( + """ name:'Q' op:'SparseConditionalAccumulator' attr { key: 'dtype' value { type: DT_FLOAT } } attr { key: 'shape' value { shape { dim {size: 1 } @@ -86,6 +95,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): } } } attr { key: 'container' value { s: '' } } attr { key: 'shared_name' value { s: '' } } + attr { key: 'reduction_type' value {s: 'MEAN'} } """, q.accumulator_ref.op.node_def) def testAccumulatorSizeEmpty(self): @@ -164,7 +174,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): result = sess.run(accums[i].take_indexed_slices_grad(1)) self._assertEqual_indexedslices(expected_tensors[i], result) - def testAccumulatorTakeGrad(self): + def testAccumulatorTakeGradMean(self): with self.test_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q", shape=()) @@ -180,9 +190,34 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): takeg_t = q.take_indexed_slices_grad(1) val = sess.run(takeg_t) - self.assertAllEqual(val.indices, [0, 1, 2]) - self.assertAllEqual(val.values, [[0.5, 0.5], [0, 2], [3, 0]]) - self.assertAllEqual(val.dense_shape, [-1, 2]) + self.assertAllEqual([0, 1, 2], val.indices) + self.assertAllEqual([[0.5, 0.5], [0, 2], [3, 0]], val.values) + self.assertAllEqual([-1, 2], val.dense_shape) + + def testAccumulatorTakeGradSum(self): + with self.test_session() as sess: + q = data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, name="Q", shape=(), reduction_type="SUM") + + grad_indexed_slices = ops.IndexedSlices( + indices=[0, 1], values=np.array([[1, 0], [0, 2]]).astype(np.float32)) + accum_op = q.apply_indexed_slices_grad(grad_indexed_slices) + accum_op.run() + accum_op = q.apply_grad([0, 2], + np.array([[0, 1], [3, 0]]).astype(np.float32), + [3, 2]) + accum_op.run() + + takeg_t = q.take_indexed_slices_grad(1) + val = sess.run(takeg_t) + self.assertAllEqual([0, 1, 2], val.indices) + self.assertAllEqual([[1, 1], [0, 2], [3, 0]], val.values) + self.assertAllEqual([-1, 2], val.dense_shape) + + def testAccumulatorTakeGradInvalidReductionType(self): + with self.assertRaises(ValueError): + data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, name="Q", shape=(), reduction_type="Invalid") def testAccumulatorRepeatedTakeGrad(self): with self.test_session() as sess: @@ -222,7 +257,7 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): self.assertAllEqual(val.values, [[5, 5], [0, 20], [30, 0]]) self.assertAllEqual(val.dense_shape, [-1, 2]) - def testParallelApplyGrad(self): + def testParallelApplyGradMean(self): with self.test_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2])) @@ -253,6 +288,40 @@ class IndexedSlicesConditionalAccumulatorTest(test.TestCase): np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32), val, sess) + def testParallelApplyGradSum(self): + with self.test_session() as sess: + q = data_flow_ops.SparseConditionalAccumulator( + dtypes_lib.float32, + name="Q", + shape=tensor_shape.TensorShape([2, 2]), + reduction_type="SUM") + elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0] + accum_ops = [] + for x in elems: + x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32)) + accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0)) + takeg_t = q.take_indexed_slices_grad(1) + + def apply_indexed_slices_grad(accum_op): + sess.run(accum_op) + + threads = [ + self.checkedThread(target=apply_indexed_slices_grad, args=(o,)) + for o in accum_ops + ] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + val = sess.run(takeg_t) + + expected_val = 550.0 + self._assertEqual_nparray( + np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32), + val, sess) + def testParallelTakeGrad(self): with self.test_session() as sess: q = data_flow_ops.SparseConditionalAccumulator( diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 7af2ca56be..69c0fcbbee 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -1229,7 +1229,8 @@ class ConditionalAccumulator(ConditionalAccumulatorBase): dtype, shape=None, shared_name=None, - name="conditional_accumulator"): + name="conditional_accumulator", + reduction_type="MEAN"): """Creates a new ConditionalAccumulator. Args: @@ -1238,9 +1239,14 @@ class ConditionalAccumulator(ConditionalAccumulatorBase): shared_name: Optional. If non-empty, this accumulator will be shared under the given name across multiple sessions. name: Optional name for the accumulator. + reduction_type: Reduction type to use when taking the gradient. """ accumulator_ref = gen_data_flow_ops.conditional_accumulator( - dtype=dtype, shape=shape, shared_name=shared_name, name=name) + dtype=dtype, + shape=shape, + shared_name=shared_name, + name=name, + reduction_type=reduction_type) super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) def apply_grad(self, grad, local_step=0, name=None): @@ -1312,15 +1318,21 @@ class SparseConditionalAccumulator(ConditionalAccumulatorBase): shared_name: Optional. If non-empty, this accumulator will be shared under the given name across multiple sessions. name: Optional name for the accumulator. + reduction_type: Reduction type to use when taking the gradient. """ def __init__(self, dtype, shape=None, shared_name=None, - name="sparse_conditional_accumulator"): + name="sparse_conditional_accumulator", + reduction_type="MEAN"): accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator( - dtype=dtype, shape=shape, shared_name=shared_name, name=name) + dtype=dtype, + shape=shape, + shared_name=shared_name, + name=name, + reduction_type=reduction_type) super(SparseConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt index d23b3bd0ca..15e0ab76b6 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt index 2260279ad2..39ff336c4f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-sparse-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt index d23b3bd0ca..15e0ab76b6 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt index 2260279ad2..39ff336c4f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-sparse-conditional-accumulator.pbtxt @@ -17,7 +17,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\'], " + argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], " } member_method { name: "apply_grad" -- GitLab From 43a3c393d7a329b7dc7aec02a7d46dc69e5a8ee1 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Thu, 6 Sep 2018 10:02:24 -0700 Subject: [PATCH 189/540] Update docstring for BoostedTrees n_batches_per_layer. PiperOrigin-RevId: 211824645 --- tensorflow/python/estimator/canned/boosted_trees.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index d104c961d3..19f18015e4 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -1000,8 +1000,11 @@ class BoostedTreesClassifier(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE classifier = estimator.BoostedTreesClassifier( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... ) @@ -1024,7 +1027,8 @@ class BoostedTreesClassifier(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is total number of data divided by + batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. @@ -1138,8 +1142,11 @@ class BoostedTreesRegressor(estimator.Estimator): bucketized_feature_2 = bucketized_column( numeric_column('feature_2'), BUCKET_BOUNDARIES_2) + # Need to see a large portion of the data before we can build a layer, for + # example half of data n_batches_per_layer = 0.5 * NUM_EXAMPLES / BATCH_SIZE regressor = estimator.BoostedTreesRegressor( feature_columns=[bucketized_feature_1, bucketized_feature_2], + n_batches_per_layer=n_batches_per_layer, n_trees=100, ... ) @@ -1162,7 +1169,8 @@ class BoostedTreesRegressor(estimator.Estimator): the model. All items in the set should be instances of classes derived from `FeatureColumn`. n_batches_per_layer: the number of batches to collect statistics per - layer. + layer. The total number of batches is total number of data divided by + batch size. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. -- GitLab From 84f091dff8e1bcd93ac2d69d2cc11faca3790ac9 Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Thu, 6 Sep 2018 10:20:55 -0700 Subject: [PATCH 190/540] Add python test for While op lowering. Test that fetching values of while outputs in sess.run by tensor name works. This tests that an IdentityN node with the same name and outputs as the original while op was added to the graph during lowering. PiperOrigin-RevId: 211827934 --- tensorflow/python/kernel_tests/BUILD | 1 + .../kernel_tests/functional_ops_test.py | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 3026c7755a..58c8975daa 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1634,6 +1634,7 @@ cuda_py_test( srcs = ["functional_ops_test.py"], additional_deps = [ "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 3ddb5e06c9..e39daf1371 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import iterator_ops @@ -738,6 +739,40 @@ class FunctionalOpsTest(test.TestCase): self.assertAllEqual(Run(sess, 20.), 210.) self.assertAllEqual(Run(sess, 100.), 5050.) + def testWhileLowering(self): + + def Run(n, fetch_by_name): + for use_gpu in (True, False): + with ops.Graph().as_default() as g: + + @function.Defun(*[dtypes.float32] * 2) + def Cond(n, unused_x): + return n > 0 + + @function.Defun(*[dtypes.float32] * 2) + def Body(n, x): + return n - 1, x + n + + # outputs: [0, n*(n+1)/2] + outputs = functional_ops.While([n, 0.], Cond, Body, name="my_while") + + # `outputs` is the list of output tensors of the While op. We + # arbitrarily choose the 0th tensor to get the While op and set the + # lowering attribute on it. + outputs[0].op._set_attr("_lower_using_switch_merge", + attr_value_pb2.AttrValue(b=True)) + if not fetch_by_name: + fetch = outputs[1] + else: + fetch = "my_while:1" + with self.test_session(graph=g, use_gpu=use_gpu) as sess: + return sess.run(fetch) + + self.assertAllEqual(Run(20., False), 210.) + self.assertAllEqual(Run(20., True), 210.) + self.assertAllEqual(Run(100., False), 5050.) + self.assertAllEqual(Run(100., True), 5050.) + def testWhileError(self): for use_gpu in (True, False): with ops.Graph().as_default() as g: -- GitLab From b9310932ce2120c8c36eb69bc135748fd3caf897 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Sep 2018 10:46:36 -0700 Subject: [PATCH 191/540] Automated rollback of commit 4cd79b3f6361b6518463349a51fe33f7520f3b49 PiperOrigin-RevId: 211832421 --- .../python/training/lazy_adam_optimizer.py | 63 +++++-------------- .../training/lazy_adam_optimizer_test.py | 17 +---- 2 files changed, 17 insertions(+), 63 deletions(-) diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py index f026f437dc..72117c1e81 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py @@ -25,11 +25,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.training import adam @@ -48,12 +46,7 @@ class LazyAdamOptimizer(adam.AdamOptimizer): may lead to different empirical results. """ - def _apply_sparse_shared(self, - grad, - var, - indices, - scatter_update, - scatter_sub): + def _apply_sparse(self, grad, var): beta1_power, beta2_power = self._get_beta_accumulators() beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) @@ -65,51 +58,23 @@ class LazyAdamOptimizer(adam.AdamOptimizer): # \\(m := beta1 * m + (1 - beta1) * g_t\\) m = self.get_slot(var, "m") - m_t = scatter_update(m, indices, - beta1_t * array_ops.gather(m, indices) + - (1 - beta1_t) * grad) + m_t = state_ops.scatter_update(m, grad.indices, + beta1_t * array_ops.gather(m, grad.indices) + + (1 - beta1_t) * grad.values, + use_locking=self._use_locking) # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\) v = self.get_slot(var, "v") - v_t = scatter_update(v, indices, - beta2_t * array_ops.gather(v, indices) + - (1 - beta2_t) * math_ops.square(grad)) + v_t = state_ops.scatter_update(v, grad.indices, + beta2_t * array_ops.gather(v, grad.indices) + + (1 - beta2_t) * math_ops.square(grad.values), + use_locking=self._use_locking) # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\) - m_t_slice = array_ops.gather(m_t, indices) - v_t_slice = array_ops.gather(v_t, indices) + m_t_slice = array_ops.gather(m_t, grad.indices) + v_t_slice = array_ops.gather(v_t, grad.indices) denominator_slice = math_ops.sqrt(v_t_slice) + epsilon_t - var_update = scatter_sub(var, indices, - lr * m_t_slice / denominator_slice) + var_update = state_ops.scatter_sub(var, grad.indices, + lr * m_t_slice / denominator_slice, + use_locking=self._use_locking) return control_flow_ops.group(var_update, m_t, v_t) - - def _apply_sparse(self, grad, var): - return self._apply_sparse_shared( - grad.values, var, grad.indices, - self._scatter_update, - self._scatter_sub) - - def _resource_apply_sparse(self, grad, var, indices): - return self._apply_sparse_shared( - grad, var, indices, - self._resource_scatter_update, - self._resource_scatter_sub) - - # Utility functions for updating resource or non-resource variables. - def _scatter_update(self, x, i, v): - return state_ops.scatter_update( - x, i, v, use_locking=self._use_locking) - - def _scatter_sub(self, x, i, v): - return state_ops.scatter_sub( - x, i, v, use_locking=self._use_locking) - - def _resource_scatter_update(self, x, i, v): - update_op = resource_variable_ops.resource_scatter_update(x.handle, i, v) - with ops.control_dependencies([update_op]): - return x.value() - - def _resource_scatter_sub(self, x, i, v): - sub_op = resource_variable_ops.resource_scatter_sub(x.handle, i, v) - with ops.control_dependencies([sub_op]): - return x.value() diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py index d3e9e89502..dc4c462ce4 100644 --- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py @@ -27,7 +27,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -52,7 +51,7 @@ def adam_update_numpy(param, class AdamOptimizerTest(test.TestCase): - def doTestSparse(self, use_resource=False): + def testSparse(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): # Initialize variables for numpy implementation. @@ -62,12 +61,8 @@ class AdamOptimizerTest(test.TestCase): var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) - if use_resource: - var0 = resource_variable_ops.ResourceVariable(var0_np) - var1 = resource_variable_ops.ResourceVariable(var1_np) - else: - var0 = variables.Variable(var0_np) - var1 = variables.Variable(var1_np) + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) grads0_np_indices = np.array([0, 1], dtype=np.int32) grads0 = ops.IndexedSlices( constant_op.constant(grads0_np), @@ -99,12 +94,6 @@ class AdamOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType(var0_np, var0.eval()) self.assertAllCloseAccordingToType(var1_np, var1.eval()) - def testSparse(self): - self.doTestSparse(use_resource=False) - - def testResourceSparse(self): - self.doTestSparse(use_resource=True) - def testSparseDevicePlacement(self): for index_dtype in [dtypes.int32, dtypes.int64]: with self.test_session(force_gpu=test.is_gpu_available()): -- GitLab From 9638524520d582e93a8038a89cd5cc62d719a3b6 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Thu, 6 Sep 2018 10:50:35 -0700 Subject: [PATCH 192/540] Job name should be picked based on the cluster_spec PiperOrigin-RevId: 211833041 --- .../cluster_resolver/python/training/tpu_cluster_resolver.py | 4 ++++ tensorflow/contrib/distribute/python/tpu_strategy.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 1ab150d74a..1056894f18 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver): def get_master(self): return self.master() + def get_job_name(self): + if self._shouldResolve(): + return self._job_name + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 4fb70ec685..6ba83976fc 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -310,7 +310,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def get_host_cpu_device(self, host_id): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' - return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,) + job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker' + return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id) def configure(self, session_config=None, -- GitLab From 58857d06e671863ebacc025d0363d564a65bb7b0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Sep 2018 10:53:51 -0700 Subject: [PATCH 193/540] Add feature_util build target so the library can be included in a lightweight way PiperOrigin-RevId: 211833556 --- tensorflow/core/BUILD | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index c06fea130f..f74379fca5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -701,6 +701,21 @@ cc_library( ], ) +cc_library( + name = "feature_util", + srcs = ["example/feature_util.cc"], + hdrs = [ + "example/feature_util.h", + "platform/types.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":core_stringpiece", + ":platform_protobuf", + ":protos_all_cc", + ], +) + cc_library( name = "abi", srcs = ["platform/abi.cc"], -- GitLab From 6d893ecfb9ba2dfc3948215557d4f8ddaf7cf51b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Sep 2018 10:55:42 -0700 Subject: [PATCH 194/540] Ignore partitioned variable in TPU computation. PiperOrigin-RevId: 211833891 --- tensorflow/contrib/tpu/python/tpu/tpu.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 1e21cc5252..c1f90c3963 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -652,13 +652,28 @@ def split_compile_and_replicate(computation, # TODO(phawkins): consider removing this code. It will # be less confusing to clients if they knowingly choose to use resource # variables. + # Partitioned variables is not supported (b/112311320). + def custom_getter(getter, name, *args, **kwargs): + partitioner = kwargs["partitioner"] + if partitioner is None: + return getter(name, *args, **kwargs) + else: + raise ValueError( + "Partitioned variables are not supported on TPU. Got " + "`partitioner` that is {}.".format(partitioner)) + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + saved_custom_getter = vscope.custom_getter + vscope.set_use_resource(True) + vscope.set_custom_getter(custom_getter) outputs = computation(*computation_inputs) vscope.set_use_resource(saved_use_resource) + vscope.set_custom_getter(saved_custom_getter) # If the computation returns `None`, make it an empty tuple. if outputs is None: -- GitLab From 025277a1598fa227b53ddc4e316a7a953b2006c8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Sep 2018 10:57:58 -0700 Subject: [PATCH 195/540] Small improvements to handling of Datasets in Keras. * Allow sparse labels to work with Datasets. * Allow sample_weights to be passed as the third output of a Dataset (like how generator input is treated). PiperOrigin-RevId: 211834259 --- .../contrib/distribute/python/keras_test.py | 3 +- tensorflow/python/keras/engine/training.py | 21 ++++++--- .../python/keras/engine/training_eager.py | 9 ++-- .../python/keras/engine/training_test.py | 43 ++++++++++++++++++- .../python/keras/engine/training_utils.py | 18 +++++--- 5 files changed, 72 insertions(+), 22 deletions(-) diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index d39fd57294..3cee3e37a7 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -446,8 +446,7 @@ class TestWithDistributionStrategy(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - with self.assertRaisesRegexp(ValueError, - 'expected input to have 2 dimensions'): + with self.assertRaisesRegexp(ValueError, 'expected input to have shape'): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0) # Wrong input shape diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 966b446f22..46149bed09 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -928,11 +928,16 @@ class Model(Network): 'Make sure that your dataset can generate ' 'required number of samples.') - if not isinstance(next_element, (list, tuple)) or len(next_element) != 2: - raise ValueError('Please provide model inputs as a list or tuple of 2 ' - 'elements: input and target pair. ' - 'Received %s' % next_element) - x, y = next_element + if (not isinstance(next_element, (list, tuple)) or + len(next_element) not in [2, 3]): + raise ValueError( + 'Please provide model inputs as a list or tuple of 2 or 3' + 'elements: (input, target) or (input, target, sample_weights)' + 'Received %s' % next_element) + if len(next_element) == 2: + x, y = next_element + else: + x, y, sample_weight = next_element x, y, sample_weights = self._standardize_weights(x, y, sample_weight, class_weight, batch_size) return x, y, sample_weights @@ -1331,7 +1336,8 @@ class Model(Network): (in case the model has multiple inputs). - A dict mapping input names to the corresponding array/tensors, if the model has named inputs. - - A `tf.data` dataset or a dataset iterator. + - A `tf.data` dataset or a dataset iterator. Should return a tuple + of either (inputs, targets) or (inputs, targets, sample_weights). y: Target data. Like the input data `x`, it could be either Numpy array(s) or TensorFlow tensor(s). It should be consistent with `x` (you cannot have Numpy inputs and @@ -1396,7 +1402,8 @@ class Model(Network): to apply a different weight to every timestep of every sample. In this case you should make sure to specify `sample_weight_mode="temporal"` in `compile()`. This argument is not - supported when `x` is a dataset or a dataset iterator. + supported when `x` is a dataset or a dataset iterator, instead + provide the sample_weights as the third element of `x`. initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run). diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py index 1e377149b6..f5bf2429d0 100644 --- a/tensorflow/python/keras/engine/training_eager.py +++ b/tensorflow/python/keras/engine/training_eager.py @@ -417,11 +417,12 @@ def iterator_predict_loop(model, inputs, steps, verbose=0): """ assert isinstance(inputs, iterator_ops.EagerIterator) if not isinstance(inputs.output_shapes, - (list, tuple)) or len(inputs.output_shapes) > 2: + (list, tuple)) or len(inputs.output_shapes) > 3: raise ValueError( - 'Please provide data as a list or tuple of 1 or 2 elements ' - ' - input or input and target pair. Received %s. We do not use the ' - '`target` value here.' % inputs.output_shapes) + 'Please provide data as a list or tuple of 1, 2, or 3 elements ' + ' - `(input)`, or `(input, target)`, or `(input, target,' + 'sample_weights)`. Received %s. We do not use the `target` or' + '`sample_weights` value here.' % inputs.output_shapes) outs = [] if verbose == 1: progbar = generic_utils.Progbar(target=steps) diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py index bf5c7fd7f8..d5c9a2ed1a 100644 --- a/tensorflow/python/keras/engine/training_test.py +++ b/tensorflow/python/keras/engine/training_test.py @@ -2097,6 +2097,43 @@ class TestTrainingWithDataset(test.TestCase): 'you should specify the `steps` argument'): model.predict(dataset, verbose=0) + @tf_test_util.run_in_graph_and_eager_modes + def test_dataset_with_sample_weights(self): + model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'mse' + metrics = ['mae', metrics_module.CategoricalAccuracy()] + model.compile(optimizer, loss, metrics=metrics) + + inputs = np.zeros((10, 3), np.float32) + targets = np.zeros((10, 4), np.float32) + sample_weights = np.ones((10), np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets, + sample_weights)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + model.train_on_batch(dataset) + model.predict_on_batch(dataset) + + @tf_test_util.run_in_graph_and_eager_modes + def test_dataset_with_sparse_labels(self): + model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) + optimizer = RMSPropOptimizer(learning_rate=0.001) + loss = 'sparse_categorical_crossentropy' + model.compile(optimizer, loss) + + inputs = np.zeros((10, 3)) + targets = np.random.randint(0, 4, size=10, dtype=np.int32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = dataset.batch(10) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + def test_dataset_input_shape_validation(self): with self.test_session(): model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) @@ -2108,8 +2145,10 @@ class TestTrainingWithDataset(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) dataset = dataset.repeat(100) - with self.assertRaisesRegexp(ValueError, - r'expected (.*?) to have 2 dimensions'): + with self.assertRaisesRegexp( + ValueError, + r'expected (.*?) to have shape \(3,\) but got array with shape \(1,\)' + ): model.train_on_batch(dataset) # Wrong input shape diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index f94697c913..ae5741d9f7 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -210,10 +210,11 @@ def check_num_samples(ins, def standardize_single_array(x): if x is None: return None - elif tensor_util.is_tensor(x): - return x - elif x.ndim == 1: - x = np.expand_dims(x, 1) + if x.shape is not None and len(x.shape) == 1: + if tensor_util.is_tensor(x): + return array_ops.expand_dims(x, axis=1) + else: + return np.expand_dims(x, 1) return x @@ -341,7 +342,7 @@ def standardize_sample_or_class_weights(x_weight, output_names, weight_type): Raises: ValueError: In case of invalid user-provided argument. """ - if x_weight is None or len(x_weight) == 0: # pylint: disable=g-explicit-length-test + if x_weight is None or (isinstance(x_weight, list) and len(x_weight) == 0): # pylint: disable=g-explicit-length-test return [None for _ in output_names] if len(output_names) == 1: if isinstance(x_weight, list) and len(x_weight) == 1: @@ -675,7 +676,8 @@ def standardize_weights(y, 'Expected sample_weight with rank ' 'less than or equal to ' + str(len(y.shape))) - if y.shape[:sample_weight.ndim] != sample_weight.shape: + if (not tensor_util.is_tensor(sample_weight) and + y.shape[:sample_weight.ndim] != sample_weight.shape): raise ValueError( 'Found a sample_weight array with shape ' + str(sample_weight.shape) + ' for an input with shape ' + str(y.shape) + '. ' @@ -777,7 +779,9 @@ def validate_iterator_input(x, y, sample_weight, validation_split=None): 'Received: %s' % (x, y)) if sample_weight is not None: raise ValueError('`sample_weight` argument is not supported when input ' - '`x` is a dataset or a dataset iterator. ' + '`x` is a dataset or a dataset iterator. Instead, you' + 'can provide sample_weight as the third element of your' + 'dataset, i.e. (inputs, targets, sample_weight). ' 'Received: x=%s, sample_weight=%s' % (x, sample_weight)) if validation_split is not None and validation_split != 0.0: raise ValueError( -- GitLab From ca5952670d98b568fa4ac671cf2310d78474c525 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 6 Sep 2018 11:03:14 -0700 Subject: [PATCH 196/540] Add StaticRegexFullMatch which can be used in place of RegexFullMatch when the regex pattern are fixed. This allows the Op to perform the expensive regex compilation once upon creation instead of with each call to compute. RELNOTES: Performance improvements for regex full match operations. PiperOrigin-RevId: 211835278 --- .../api_def_StaticRegexFullMatch.pbtxt | 29 +++++++++ .../core/kernels/regex_full_match_op.cc | 33 ++++++++++ tensorflow/core/ops/string_ops.cc | 6 ++ tensorflow/python/kernel_tests/BUILD | 1 + .../kernel_tests/regex_full_match_op_test.py | 60 +++++++++++++++---- tensorflow/python/ops/string_ops.py | 34 ++++++++++- 6 files changed, 151 insertions(+), 12 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt new file mode 100644 index 0000000000..6d9d9908ca --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "StaticRegexFullMatch" + in_arg { + name: "input" + description: <