diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8669c25c452b53da48239bc20c9a2d3528e75422..db4b1581ae671b1e676e215c9a80dfaab832fa21 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -90,7 +90,7 @@ Bazel BUILD files also need to include a license section, e.g., Changes to TensorFlow C++ code should conform to [Google C++ Style Guide](https://google.github.io/styleguide/cppguide.html). -Use `clang-tidy` to check your C/C++ changes. To install clang-tidy on ubuntu:16.04, do: +Use `clang-tidy` to check your C/C++ changes. To install `clang-tidy` on ubuntu:16.04, do: ```bash apt-get install -y clang-tidy diff --git a/README.md b/README.md index 6fb4486d0de9ff476b5cf1dbd63d66879637df84..63853137cfd30b396f8c7d204811f3e4a1794c07 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ $ python 42 >>> sess.close() ``` +Learn more examples about how to do specific tasks in TensorFlow at the [tutorials page of tensorflow.org](https://www.tensorflow.org/tutorials/). ## Contribution guidelines diff --git a/RELEASE.md b/RELEASE.md index 84d9d52868ecd55d38d6073315749d11c2340e8c..e09e9c6190f57adec67c2ae1d85848dabfd9c2a7 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,62 @@ +# Release 1.9.0 + +## Major Features And Improvements +* Update tf.keras to the Keras 2.1.6 API. +* `tfe.Network` is deprecated. Please inherit from `tf.keras.Model`. +* Adding support of core feature columns and losses to gradient boosted trees estimators. +* The distributions.Bijector API supports broadcasting for Bijectors with new API changes. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/distributions/bijectors/Bijector) for more details. +* Layered variable names have changed in the following conditions: + * Using `tf.keras.layers` with custom variable scopes. + * Using `tf.layers` in a subclassed `tf.keras.Model` class. See [here](https://www.tensorflow.org/versions/r1.9/api_docs/python/tf/layers) for more details + +## Breaking Chances + * If you're opening empty variable scopes; replace `variable_scope`('', ...) by `variable_scope`(`tf.get_variable_scope()`, ...). + +## Bug Fixes and Other Changes +* `tf.data`: + * The `DatasetBase::DebugString()` method is now `const`. + * Added the `tf.contrib.data.sample_from_datasets()` API for randomly sampling from multiple datasets. +* Eager Execution: +* `tf.keras`: + * Move Keras code out of _impl folder and remove API files. + * `tf.keras.Model.save_weights` now saves in TensorFlow format by default. + * Enable dataset iterators to be passed to `tf.keras.Model` training/eval methods. +* Accelerated Linear Algebra (XLA): +* TensorFlow Debugger (tfdbg): fix an issue in which the TensorBoard Debugger Plugin could not handle total source file size exceeding gRPC message size limit (4 MB). +* `tf.contrib`: + * Add `tf.contrib.data.choose_from_datasets()`. + * `tf.contrib.data.make_csv_dataset()` now supports line breaks in quoted strings. Two arguments were removed from `make_csv_dataset`. + * `tf.contrib.framework.zero_initializer` supports ResourceVariable. + * Adding "constrained_optimization" to tensorflow/contrib. +* Other: + * Add GCS Configuration Ops. + * Changing signature of `MakeIterator` to enable propagating error status. + * KL divergence for two Dirichlet distributions. + * More consistent GcsFileSystem behavior for certain reads past EOF. + * Update benchmark for tf.scan to match ranges across eager and graph modes. + * Fixed bug in `tf.reduce_prod gradient` for complex dtypes. + * Add optional `args` argument to `Dataset.from_generator()`. + * Allow the use of '.' in variables (e.g. "hparams.parse('a.b=1.0')"), which would previously raise an error. This will correspond to an attribute name with an embedded '.' symbol (e.g. 'a.b'), which can only be accessed indirectly (e.g. through getattr and setattr). To set this up the user will first need to explicitly add the variable to the hparam object (e.g. "hparams.add_hparam(name='a.b', value=0.0)"). + * Benchmark for tf.scan in graph and eager modes. + * Added complex128 support to FFT, FFT2D, FFT3D, IFFT, IFFT2D, and IFFT3D. + * Making ids unique in `nn.embedding_lookup_sparse`. This helps to reduce RPC calls for looking up the embeddings when there are repeated ids in the batch. + * Support indicator column in boosted trees. + * Prevent `tf.gradients()` from backpropagating through integer tensors. + * LinearOperator[1D,2D,3D]Circulant added to `tensorflow.linalg`. + * Conv3D, Conv3DBackpropInput, Conv3DBackpropFilter now supports arbitrary. + * Added `tf.train.Checkpoint` for reading/writing object-based checkpoints. + * `Dataset.list_files()` now produces determinstic results when `shuffle=False` or a `seed` is passed. + * Added LinearOperatorKronecker, a dense-free implementation of the Kronecker Product. + * Allow LinearOperator to broadcast. + * SavedModelBuilder will now deduplicate asset names that point to files with the same basename and the same contents. Note that this may result in new asset files included in SavedModels in cases where assets with the same name but different contents were previously overwriting each other. + + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Abdullah Alrasheed, Achal Shah, Ad-530, ADiegoCAlonso, Aditya Yogi, Ag Ramesh, akindyakov, Andy Kernahan, Anya Petrova, Aurelien Geron, Ben, Ben Barsdell, Bhavani-Subramanian, braincodercn, Brett Koonce, Brian Nemsick, Brian Zier, Bryan Heden, candy.dc, cclauss, Clayne Robison, ctiijima, Dalmo Cirne, David Norman, David T.H. Kao, DosLin, ekelsen, Elson Rodriguez, Erik Smistad, Felix Abecassis, Fergal Cotter, fo40225, foo0x29a, Freedom" Koan-Sin Tan, FréDéRic Branchaud-Charron, gdh1995, Geoffrey Irving, Giuseppe, gracehoney, Guido Zuidhof, Guillaume Klein, Guozhong Zhuang, Haggai, Harald Husum, imsheridan, Ivan Zhang, Jan Zikes, Jayaram Bobba, Jesse Benson, Jesse Gumz, Jiajia Li, Jie, jinghuangintel, Jingwen, jjsjann123, Joe Yearsley, Joel Hestness, Joel Shor, josephyearsley, Junpeng Lao, Karol M. Langner, Kb Sriram, krantideep95, Krish Ravindranath, Letian Feng, Loo Rong Jie, Lukas Geiger, Maciej, Mahmoud Abuzaina, ManHyuk, Mark Ryan, mbhuiyan, Michal Turek, Mostafa Alaa, Myungsung Kwak, Nand Dalal, Nehal J Wani, Neil Tenenholtz, ngc92, Nicholas Nadeau, P.Eng., Avs, Niranjan Hasabnis, P-Hidringer, Paul Van Eck, Peng Yu, Qing Zhao, Qingying Chen, Quanlong, Rajendra Arora, Rholais Lii, rmanyari, Robin Richtsfeld, Russell Klopfer, Sagi, Sam Sendelbach, Sandeep N Gupta, Sandip Giri, Sarah Edkins, Scott Tseng, Sdalbsoo, Sergii Khomenko, Seungwoo Choi (Biggie), Seyed Majid Azimi, Shaoning Zeng, shengfuintel, Siu Kei, Muk, Smit Shilu, soonson, Stefan Schweter, Sukhwan Kim, Sunitha Kambhampati, Taehoon Lee, tamimaddari82, Tang, Wenyi, Ted Chang, u2takey, Utkarsh Upadhyay, Vadim Markovtsev, voegtlel, Wai Hon Law, wangsiyu, Wenhao Hu, wenhao.hu, William D. Irons, Yan Facai (颜发才), Yanbo Liang, Yihong Wang, Yilei (Dolee) Yang, Yong Tang, Yuan (Terry) Tang + # Release 1.8.0 ## Major Features And Improvements @@ -404,14 +463,6 @@ answered questions, and were part of inspiring discussions. # Release 1.4.0 -## Major Features And Improvements -* `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of - the core TensorFlow API. - * The API is now subject to backwards compatibility guarantees. - -# Release 1.4.0 - ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. * [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of diff --git a/SECURITY.md b/SECURITY.md index 01886b613e5d93793953124331b57f075fe7a373..e2f6ff353a3c04a6ec6b8ccbaeb75db59fa22d54 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -168,7 +168,7 @@ below). Please use a descriptive subject line for your report email. After the initial reply to your report, the security team will endeavor to keep you informed of -the progress being made towards a fix and announcement. +the progress being made towards a fix and announcement. In addition, please include the following information along with your report: @@ -242,9 +242,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= -----END PGP PUBLIC KEY BLOCK----- ``` -### Known vulnerabilities - -| Type | Versions affected | Reported by | Additional Information | -|--------------------|:-----------------:|-----------------------|-----------------------------| -| Out Of Bounds Read | <=1.4 | Blade Team of Tencent | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +### Known Vulnerabilities +For a list of known vulnerabilities and security advisories for TensorFlow, +(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md)[click here]. diff --git a/WORKSPACE b/WORKSPACE index 4ddfb9a3832ea1ea639ace887e1d601bdd857086..fd7570a80ae2ee0087f7d2fd771fcce5b9690028 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -22,26 +22,10 @@ check_bazel_version_at_least("0.10.0") load("//tensorflow:workspace.bzl", "tf_workspace") -# Uncomment and update the paths in these entries to build the Android demo. -#android_sdk_repository( -# name = "androidsdk", -# api_level = 23, -# # Ensure that you have the build_tools_version below installed in the -# # SDK manager as it updates periodically. -# build_tools_version = "26.0.1", -# # Replace with path to Android SDK on your system -# path = "", -#) -# -#android_ndk_repository( -# name="androidndk", -# path="", -# # This needs to be 14 or higher to compile TensorFlow. -# # Please specify API level to >= 21 to build for 64-bit -# # archtectures or the Android NDK will automatically select biggest -# # API level that it supports without notice. -# # Note that the NDK version is not the API level. -# api_level=14) +load("//third_party/android:android_configure.bzl", "android_configure") +android_configure(name="local_config_android") +load("@local_config_android//:android.bzl", "android_workspace") +android_workspace() # Please add all new TensorFlow dependencies in workspace.bzl. tf_workspace() diff --git a/configure.py b/configure.py index b6c32543cf707983d48e390cc89abf13dafd55d3..ada342a50ab5104509156d3e44e6435a308255a3 100644 --- a/configure.py +++ b/configure.py @@ -670,8 +670,9 @@ def create_android_ndk_rule(environ_cp): error_msg=('The path %s or its child file "source.properties" ' 'does not exist.') ) - - write_android_ndk_workspace_rule(android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_HOME', android_ndk_home_path) + write_action_env_to_bazelrc('ANDROID_NDK_API_LEVEL', + check_ndk_level(android_ndk_home_path)) def create_android_sdk_rule(environ_cp): @@ -733,41 +734,12 @@ def create_android_sdk_rule(environ_cp): error_msg=('The selected SDK does not have build-tools version %s ' 'available.')) - write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level) - - -def write_android_sdk_workspace_rule(android_sdk_home_path, - android_build_tools_version, - android_api_level): - print('Writing android_sdk_workspace rule.\n') - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_sdk_repository( - name="androidsdk", - api_level=%s, - path="%s", - build_tools_version="%s")\n -""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) - - -def write_android_ndk_workspace_rule(android_ndk_home_path): - print('Writing android_ndk_workspace rule.') - ndk_api_level = check_ndk_level(android_ndk_home_path) - if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: - print('WARNING: The API level of the NDK in %s is %s, which is not ' - 'supported by Bazel (officially supported versions: %s). Please use ' - 'another version. Compiling Android targets may result in confusing ' - 'errors.\n' % (android_ndk_home_path, ndk_api_level, - _SUPPORTED_ANDROID_NDK_VERSIONS)) - with open(_TF_WORKSPACE, 'a') as f: - f.write(""" -android_ndk_repository( - name="androidndk", - path="%s", - api_level=%s)\n -""" % (android_ndk_home_path, ndk_api_level)) + write_action_env_to_bazelrc('ANDROID_BUILD_TOOLS_VERSION', + android_build_tools_version) + write_action_env_to_bazelrc('ANDROID_SDK_API_LEVEL', + android_api_level) + write_action_env_to_bazelrc('ANDROID_SDK_HOME', + android_sdk_home_path) def check_ndk_level(android_ndk_home_path): @@ -780,18 +752,16 @@ def check_ndk_level(android_ndk_home_path): revision = re.search(r'Pkg.Revision = (\d+)', filedata) if revision: - return revision.group(1) - return None - - -def workspace_has_any_android_rule(): - """Check the WORKSPACE for existing android_*_repository rules.""" - with open(_TF_WORKSPACE, 'r') as f: - workspace = f.read() - has_any_rule = re.search(r'^android_[ns]dk_repository', - workspace, - re.MULTILINE) - return has_any_rule + ndk_api_level = revision.group(1) + else: + raise Exception('Unable to parse NDK revision.') + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + return ndk_api_level def set_gcc_host_compiler_path(environ_cp): @@ -1223,7 +1193,7 @@ def set_tf_cuda_compute_capabilities(environ_cp): # Check whether all capabilities from the input is valid all_valid = True # Remove all whitespace characters before splitting the string - # that users may insert by accident, as this will result in error + # that users may insert by accident, as this will result in error tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) for compute_capability in tf_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) @@ -1427,6 +1397,10 @@ def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_build_strip_flag(): + write_to_bazelrc('build --strip=always') + + def set_windows_build_flags(): if is_windows(): # The non-monolithic build is not supported yet @@ -1549,23 +1523,18 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) + set_build_strip_flag() set_windows_build_flags() - if workspace_has_any_android_rule(): - print('The WORKSPACE file has at least one of ["android_sdk_repository", ' - '"android_ndk_repository"] already set. Will not ask to help ' - 'configure the WORKSPACE. Please delete the existing rules to ' - 'activate the helper.\n') - else: - if get_var( - environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', - False, - ('Would you like to interactively configure ./WORKSPACE for ' - 'Android builds?'), - 'Searching for NDK and SDK installations.', - 'Not configuring the WORKSPACE for Android builds.'): - create_android_ndk_rule(environ_cp) - create_android_sdk_rule(environ_cp) + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) print('Preconfigured Bazel build configs. You can use any of the below by ' 'adding "--config=<>" to your build command. See tools/bazel.rc for ' diff --git a/tensorflow/BUILD b/tensorflow/BUILD index f2ad16fa04f5beb6616c58c28d0f0c460c3e3a17..6d134dbb80cb8c3dcf15b2ba20783870a67e9a62 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -19,6 +19,10 @@ load( "//tensorflow/core:platform/default/build_config.bzl", "tf_additional_binary_deps", ) +load( + "//tensorflow/tools/api/generator:api_gen.bzl", + "gen_api_init_files", # @unused +) # Config setting for determining if we are building for Android. config_setting( @@ -471,7 +475,7 @@ tf_cc_shared_object( # excludes all but a subset of function names. # On MacOS, the linker does not support version_script, but has an # an "-exported_symbols_list" command. -z defs disallows undefined -# symbols in object files and -s strips the output. +# symbols in object files. tf_cc_shared_object( name = "libtensorflow.so", @@ -485,7 +489,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow/c:version_script.lds)", ], @@ -511,7 +514,6 @@ tf_cc_shared_object( "//tensorflow:windows_msvc": [], "//conditions:default": [ "-z defs", - "-s", "-Wl,--version-script", # This line must be directly followed by the version_script.lds file "$(location //tensorflow:tf_version_script.lds)", ], @@ -536,13 +538,19 @@ exports_files( ], ) +gen_api_init_files( + name = "tensorflow_python_api_gen", + srcs = ["api_template.__init__.py"], + root_init_template = "api_template.__init__.py", +) + py_library( name = "tensorflow_py", - srcs = ["__init__.py"], + srcs = [ + ":tensorflow_python_api_gen", + "//tensorflow/python/estimator/api:estimator_python_api_gen", + ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [ - "//tensorflow/python", - "//tensorflow/tools/api/generator:python_api", - ], + deps = ["//tensorflow/python"], ) diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index c8683e3976c90add3f1f54d8e575c798327e9273..440e9f8dbd2f4b2a2ab78eaaf26408584e7c1446 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -22,9 +22,6 @@ from __future__ import print_function # pylint: disable=g-bad-import-order from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import -# pylint: disable=wildcard-import -from tensorflow.tools.api.generator.api import * # pylint: disable=redefined-builtin -# pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9662d7b478ba61c69edc20b0d47293f9939e7881 --- /dev/null +++ b/tensorflow/api_template.__init__.py @@ -0,0 +1,58 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Bring in all of the public TensorFlow interface into this module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=g-bad-import-order +from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import +# API IMPORTS PLACEHOLDER + +try: + import os # pylint: disable=g-import-not-at-top + # Add `estimator` attribute to allow access to estimator APIs via + # "tf.estimator..." + from tensorflow.python.estimator.api import estimator # pylint: disable=g-import-not-at-top + + # Add `estimator` to the __path__ to allow "from tensorflow.estimator..." + # style imports. + from tensorflow.python.estimator import api as estimator_api # pylint: disable=g-import-not-at-top + __path__ += [os.path.dirname(estimator_api.__file__)] + del estimator_api + del os +except (ImportError, AttributeError): + print('tf.estimator package not installed.') + +from tensorflow.python.util.lazy_loader import LazyLoader # pylint: disable=g-import-not-at-top +contrib = LazyLoader('contrib', globals(), 'tensorflow.contrib') +del LazyLoader + +from tensorflow.python.platform import flags # pylint: disable=g-import-not-at-top +app.flags = flags # pylint: disable=undefined-variable + +del absolute_import +del division +del print_function + +# These symbols appear because we import the python package which +# in turn imports from tensorflow.core and tensorflow.python. They +# must come from this module. So python adds these symbols for the +# resolution to succeed. +# pylint: disable=undefined-variable +del python +del core +# pylint: enable=undefined-variable diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index b86b277ac3200b88ae03490a6c1b64d464e81950..cb0b093ad260e000dcef9d1123e967a77cf1a041 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -631,7 +631,22 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, "Failed to allocate memory to serialize message of type '", in.GetTypeName(), "' and size ", proto_size); } - in.SerializeToArray(buf, proto_size); + // SerializeToArray takes size as an int. + // This next 'if' is a workaround till we update to depend on a version + // of protocol buffers that includes + // https://github.com/google/protobuf/pull/4739 + if (proto_size > std::numeric_limits::max()) { + return InvalidArgument("Cannot serialize protocol buffer of type ", + in.GetTypeName(), " as the serialized size (", + proto_size, + "bytes) would be larger than the limit (", + std::numeric_limits::max(), " bytes)"); + } + if (!in.SerializeToArray(buf, proto_size)) { + return InvalidArgument("Unable to serialize ", in.GetTypeName(), + " protocol buffer, perhaps the serialized size (", + proto_size, " bytes) is too large?"); + } out->data = buf; out->length = proto_size; out->data_deallocator = [](void* data, size_t length) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 9ce781fab0be709fb0f115a2206eea4c2826bf36..f265da2c2c89c0e9caf14f2213c606fcb69997e0 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -14,6 +14,7 @@ tf_cuda_library( name = "c_api", srcs = [ "c_api.cc", + "c_api_debug.cc", "c_api_internal.h", ], hdrs = ["c_api.h"], @@ -45,6 +46,7 @@ tf_cuda_library( "//tensorflow:with_xla_support": [ "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:xla_device", ], "//conditions:default": [], }) + [ @@ -99,9 +101,31 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "c_api_test_util", + testonly = 1, + srcs = ["c_api_test_util.cc"], + hdrs = ["c_api_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + ], + deps = [ + ":c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + tf_cuda_cc_test( name = "c_api_test", - srcs = ["c_api_test.cc"], + srcs = [ + "c_api_debug_test.cc", + "c_api_test.cc", + ], extra_copts = tfe_xla_copts(), tags = [ "guitar", @@ -109,6 +133,7 @@ tf_cuda_cc_test( ], deps = [ ":c_api", + ":c_api_test_util", "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 216210c88c1593ebc68f604547ab06b543a7b2af..81221c4078bec9820ee187efdf0314da378be62b 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -73,10 +73,6 @@ string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } -#ifdef TENSORFLOW_EAGER_USE_XLA -std::atomic_int_fast64_t func_id_generator(0); -#endif // TENSORFLOW_EAGER_USE_XLA - tensorflow::Status GetAllRemoteDevices( const std::vector& remote_workers, tensorflow::WorkerCacheInterface* worker_cache, diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 574a097e0d6f5d6e7acd77cae246678b6675129b..1862af3ce2f505a6e83b4805417eaf335ed07bc0 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -191,6 +191,45 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, TF_Status* status); +// Debugging/Profiling information for TFE_TensorHandle +// +// TFE_TensorDebugInfo contains information useful for debugging and +// profiling tensors. +typedef struct TFE_TensorDebugInfo TFE_TensorDebugInfo; + +// Retrieves TFE_TensorDebugInfo for `handle`. +// If TFE_TensorHandleTensorDebugInfo succeeds, `status` is set to OK and caller +// is responsible for deleting returned TFE_TensorDebugInfo. +// If TFE_TensorHandleTensorDebugInfo fails, `status` is set to appropriate +// error and nullptr is returned. This function can block till the operation +// that produces `handle` has completed. +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status); + +// Deletes `debug_info`. +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of dimensions used to represent the tensor on its device. +// The number of dimensions used to reprensent the tensor on device can be +// different from the number returned by TFE_TensorHandleNumDims. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info); + +// Returns the number of elements in dimension `dim_index`. +// Tensor representation on device can be transposed from its representation +// on host. The data contained in dimension `dim_index` on device +// can correspond to the data contained in another dimension in on-host +// representation. The dimensions are indexed using the standard TensorFlow +// major-to-minor order (slowest varying dimension first), +// not the XLA's minor-to-major order. +// On-device dimensions can be padded. TFE_TensorDebugInfoOnDeviceDim returns +// the number of elements in a dimension after padding. +// The return value was current at the time of TFE_TensorDebugInfo creation. +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index); + // Description of the TensorFlow op to execute. // // Assumes that the provided 'ctx' outlives the returned TFE_Op, i.e., diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc new file mode 100644 index 0000000000000000000000000000000000000000..5006b76f1981d068e99a2c081115ebb3a66d8c7f --- /dev/null +++ b/tensorflow/c/eager/c_api_debug.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api_internal.h" +#ifdef TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/compiler/jit/xla_device.h" +#endif // TENSORFLOW_EAGER_USE_XLA + +using tensorflow::int64; +using tensorflow::string; + +namespace { + +std::vector TensorShapeAsVector(TFE_TensorHandle* handle, + TF_Status* status) { + std::vector shape; + int rank = TFE_TensorHandleNumDims(handle, status); + if (!status->status.ok()) { + return shape; + } + shape.reserve(rank); + for (int i = 0; i < rank; ++i) { + shape.push_back(TFE_TensorHandleDim(handle, i, status)); + if (!status->status.ok()) { + return shape; + } + } + return shape; +} + +} // namespace + +extern "C" { + +TF_CAPI_EXPORT extern TFE_TensorDebugInfo* TFE_TensorHandleTensorDebugInfo( + TFE_TensorHandle* handle, TF_Status* status) { + const tensorflow::Tensor* tensor; + status->status = handle->handle->Tensor(&tensor); + if (!status->status.ok()) { + return nullptr; + } + + tensorflow::Device* device; + status->status = handle->handle->Device(&device); + if (!status->status.ok()) { + return nullptr; + } + +#ifdef TENSORFLOW_EAGER_USE_XLA + // If tensor resides on an XLA device, use XLA device's PaddedShapeFn. + tensorflow::XlaDevice* xla_device = + dynamic_cast(device); + if (xla_device != nullptr) { + tensorflow::XlaDevice::PaddedShapeFn shape_fn = + xla_device->metadata().padded_shape_fn(); + xla::Shape padded_shape; + status->status = shape_fn(*tensor, &padded_shape); + if (!status->status.ok()) { + return nullptr; + } + if (VLOG_IS_ON(3)) { + std::vector shape_to_log = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + // Ignore the status here as we are simply logging. + status->status = tensorflow::Status::OK(); + } else { + VLOG(3) << "Fully padded shape of [" + << tensorflow::str_util::Join(shape_to_log, ", ") << "] is " + << padded_shape.DebugString(); + } + } + + if (xla::ShapeUtil::IsTuple(padded_shape)) { + if (xla::ShapeUtil::TupleElementCount(padded_shape) != 2) { + // Currently, the only case of XlaTensor containing a tuple shape is to + // represent 64 bit ints, doubles, and complex numbers (we don't support + // 64bit complex numbers). + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should only contain tuples of size 2. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // shape0 is not a const& because we will assign it to padded_shape below. + // It is illegal to assign a part of a message to itself. + xla::Shape shape0 = xla::ShapeUtil::GetTupleElementShape(padded_shape, 0); + const xla::Shape& shape1 = + xla::ShapeUtil::GetTupleElementShape(padded_shape, 1); + if (xla::ShapeUtil::IsTuple(shape0) || xla::ShapeUtil::IsTuple(shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "XlaTensors should not contain nested tuples. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + if (!xla::ShapeUtil::Equal(shape0, shape1)) { + status->status = tensorflow::errors::InvalidArgument( + "Subshapes of XlaTensors should be the same. Shape: ", + padded_shape.DebugString()); + return nullptr; + } + + // Since the only case we handle here are two equal subshapes, we + // simply return one of them. The caller will interpret it as this + // shape directly storing the 64bit types. This approximation is good + // enough for this API's debugging use case. + padded_shape = shape0; + } + + int rank = padded_shape.dimensions_size(); + std::vector dev_dims; + dev_dims.reserve(rank); + if (rank == 1) { + // Rank 1 tensors might not have padded_shape.layout.minor_to_major set, + dev_dims.push_back(padded_shape.dimensions(0)); + } else { + for (int i = rank - 1; i >= 0; --i) { + int64 dim_index = padded_shape.layout().minor_to_major(i); + dev_dims.push_back(padded_shape.dimensions(dim_index)); + } + } + status->status = tensorflow::Status::OK(); + return new TFE_TensorDebugInfo(dev_dims); + } +#endif // TENSORFLOW_EAGER_USE_XLA + + // If the tensor is not an XLA tensor, the device shape is + // the same as regular tensor shape. + std::vector dev_dims = TensorShapeAsVector(handle, status); + if (!status->status.ok()) { + return nullptr; + } + return new TFE_TensorDebugInfo(dev_dims); +} + +TF_CAPI_EXPORT extern void TFE_DeleteTensorDebugInfo( + TFE_TensorDebugInfo* debug_info) { + delete debug_info; +} + +TF_CAPI_EXPORT extern int TFE_TensorDebugInfoOnDeviceNumDims( + TFE_TensorDebugInfo* debug_info) { + return debug_info->dev_dims.size(); +} + +TF_CAPI_EXPORT extern int64_t TFE_TensorDebugInfoOnDeviceDim( + TFE_TensorDebugInfo* debug_info, int dim_index) { + return debug_info->dev_dims[dim_index]; +} + +} // extern "C" diff --git a/tensorflow/c/eager/c_api_debug_test.cc b/tensorflow/c/eager/c_api_debug_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cddb9f6e00e9d639026f4bbe061d58f76771c0a9 --- /dev/null +++ b/tensorflow/c/eager/c_api_debug_test.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +TEST(CApiDebug, ScalarCPU) { + TFE_TensorHandle* h = TestScalarTensorHandle(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(0, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} + +TEST(CApiDebug, 2DCPU) { + TFE_TensorHandle* h = TestMatrixTensorHandle3X2(); + TF_Status* status = TF_NewStatus(); + TFE_TensorDebugInfo* debug_info = TFE_TensorHandleTensorDebugInfo(h, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + ASSERT_EQ(2, TFE_TensorDebugInfoOnDeviceNumDims(debug_info)); + // Shape is the same for CPU tensors. + EXPECT_EQ(3, TFE_TensorDebugInfoOnDeviceDim(debug_info, 0)); + EXPECT_EQ(2, TFE_TensorDebugInfoOnDeviceDim(debug_info, 1)); + + TFE_DeleteTensorDebugInfo(debug_info); + TFE_DeleteTensorHandle(h); + TF_DeleteStatus(status); +} diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 2b8384d72038c3b4a050c70b3e7c5e0ca0bd94f3..04a6efc47c5177c82b7e88168b67cc584587de7c 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -107,6 +107,14 @@ struct TFE_TensorHandle { tensorflow::TensorHandle* handle; }; +struct TFE_TensorDebugInfo { + TFE_TensorDebugInfo(const std::vector& dims) + : dev_dims(dims) {} + + // Fully-padded, minor-to-major. + std::vector dev_dims; +}; + struct TFE_Op { // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a // primitive operation. diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 49646bb73599d96fce2df90f918e692df7972aeb..992d1afd5fcb0641794bb2abbe5ab20a287d3b62 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include +#include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/core/distributed_runtime/rpc/eager/eager_grpc_server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -32,122 +33,6 @@ using tensorflow::string; namespace { -TFE_TensorHandle* DoubleTestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle() { - int64_t dims[] = {2, 2}; - float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_TensorHandle* TestMatrixTensorHandle3X2() { - int64_t dims[] = {3, 2}; - double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; - TF_Tensor* t = TF_AllocateTensor( - TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, a, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, b, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteStatus(status); - TFE_OpSetAttrBool(op, "transpose_a", 0); - TFE_OpSetAttrBool(op, "transpose_b", 0); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); - - return op; -} - -TFE_TensorHandle* TestAxisTensorHandle() { - int64_t dims[] = {1}; - int data[] = {1}; - TF_Tensor* t = TF_AllocateTensor( - TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); - memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); - TF_Status* status = TF_NewStatus(); - TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteTensor(t); - TF_DeleteStatus(status); - return th; -} - -TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, - TFE_TensorHandle* axis) { - TF_Status* status = TF_NewStatus(); - - TFE_Op* op = TFE_NewOp(ctx, "Min", status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, input, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpAddInput(op, axis, status); - CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TFE_OpSetAttrBool(op, "keep_dims", 1); - TFE_OpSetAttrType(op, "Tidx", TF_INT32); - TF_DeleteStatus(status); - TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); - - return op; -} - -// If there is a GPU device, returns true and sets 'gpu_device_name' -// accordingly. -bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); - CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); - - const int num_devices = TF_DeviceListCount(devices); - for (int i = 0; i < num_devices; ++i) { - const string device_type(TF_DeviceListType(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - const string device_name(TF_DeviceListName(devices, i, status.get())); - CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); - if (device_type == "GPU") { - *gpu_device_name = device_name; - LOG(INFO) << "Found GPU device " << device_name; - TF_DeleteDeviceList(devices); - return true; - } - } - TF_DeleteDeviceList(devices); - return false; -} - void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -257,8 +142,10 @@ void TestRemoteExecute(bool async) { TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), status); - TFE_ContextOptionsSetAsync(opts, static_cast(1)); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); TFE_Context* ctx = TFE_NewContext(opts, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); @@ -320,6 +207,83 @@ void TestRemoteExecute(bool async) { TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); } TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); } +void TestRemoteExecuteSilentCopies(bool async) { + tensorflow::ServerDef server_def = GetServerDef(2); + + // This server def has the task index set to 0. + string serialized = server_def.SerializeAsString(); + + server_def.set_task_index(1); + + std::unique_ptr worker_server; + ASSERT_TRUE( + tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server) + .ok()); + ASSERT_TRUE(worker_server->Start().ok()); + + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(), + status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_ContextOptionsSetAsync(opts, static_cast(1)); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(); + TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(); + const char remote_device_name[] = + "/job:localhost/replica:0/task:1/device:CPU:0"; + + // Handles are on task0, but op is on remote (task1). + TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0); + TFE_OpSetDevice(matmul, remote_device_name, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + auto* retval_task0 = TFE_TensorHandleCopyToDevice( + retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteTensorHandle(retval_task0); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TFE_DeleteTensorHandle(h0_task0); + TFE_DeleteTensorHandle(h1_task0); + TFE_DeleteTensorHandle(retvals[0]); + + TFE_DeleteOp(matmul); + + TFE_ContextAsyncWait(ctx, status); + TFE_DeleteContext(ctx, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TF_DeleteStatus(status); + + // TODO(nareshmodi): Figure out how to correctly shut the server down. + worker_server.release(); +} + +TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); } +TEST(CAPI, RemoteExecuteSilentCopiesAsync) { + TestRemoteExecuteSilentCopies(true); +} + TEST(CAPI, TensorHandle) { TFE_TensorHandle* h = TestMatrixTensorHandle(); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); @@ -536,7 +500,7 @@ void TensorHandleSilentCopy(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -583,7 +547,7 @@ void TensorHandleSilentCopyLocal(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); @@ -624,7 +588,7 @@ void SetAndGetOpDevices(bool async) { // Disable the test if no GPU is present. string gpu_device_name; - if (GetGPUDeviceName(ctx, &gpu_device_name)) { + if (GetDeviceName(ctx, &gpu_device_name, "GPU")) { TFE_OpSetDevice(matmul, "GPU:0", status); ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); const char* device_name = TFE_OpGetDevice(matmul, status); @@ -688,7 +652,7 @@ void Execute_MatMul_CPU_Runtime_Error(bool async) { TFE_DeleteContextOptions(opts); TFE_TensorHandle* m1 = TestMatrixTensorHandle(); - TFE_TensorHandle* m2 = TestMatrixTensorHandle3X2(); + TFE_TensorHandle* m2 = DoubleTestMatrixTensorHandle3X2(); TFE_Op* matmul = MatMulOp(ctx, m1, m2); TFE_OpSetDevice(matmul, "/job:localhost/replica:0/task:0/device:CPU:0", status); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..5607c9dcb0bbec72b2f86def3dd4e6590d73197b --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -0,0 +1,163 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/c_api_test_util.h" + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +using tensorflow::string; + +TFE_TensorHandle* TestScalarTensorHandle() { + float data[] = {1.0f}; + TF_Tensor* t = TF_AllocateTensor(TF_FLOAT, nullptr, 0, sizeof(float)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_DOUBLE, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle() { + int64_t dims[] = {2, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + double data[] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestMatrixTensorHandle3X2() { + int64_t dims[] = {3, 2}; + float data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + TF_Tensor* t = TF_AllocateTensor( + TF_FLOAT, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "MatMul", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, a, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, b, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + TFE_OpSetAttrBool(op, "transpose_a", 0); + TFE_OpSetAttrBool(op, "transpose_b", 0); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a)); + + return op; +} + +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + +bool GetDeviceName(TFE_Context* ctx, string* device_name, + const char* device_type) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string dev_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string dev_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (dev_type == device_type) { + *device_name = dev_name; + LOG(INFO) << "Found " << device_type << " device " << *device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} diff --git a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..474cae67c89249af3a62707f0db00ba458ca8f31 --- /dev/null +++ b/tensorflow/c/eager/c_api_test_util.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ +#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include "tensorflow/core/platform/types.h" + +// Return a tensor handle containing a float scalar +TFE_TensorHandle* TestScalarTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle(); + +// Return a tensor handle containing a 2x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle(); + +// Return a tensor handle containing a 3x2 matrix of doubles +TFE_TensorHandle* DoubleTestMatrixTensorHandle3X2(); + +// Return a tensor handle containing a 3x2 matrix of floats +TFE_TensorHandle* TestMatrixTensorHandle3X2(); + +// Return a matmul op multiplying `a` by `b`. +TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b); + +// Return an 1-D INT32 tensor containing a single value 1. +TFE_TensorHandle* TestAxisTensorHandle(); + +// Return an op taking minimum of `input` long `axis` dimension. +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis); + +// If there is a device of type `device_type`, returns true +// and sets 'device_name' accordingly. +// `device_type` must be either "GPU" or "TPU". +bool GetDeviceName(TFE_Context* ctx, tensorflow::string* device_name, + const char* device_type); + +#endif // TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 1833b25fea0047c9652318e49599ba623daaec26..734e712daa39c03f0177eb199b1acb1b19e5d845 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -48,7 +48,7 @@ struct OpTapeEntry { // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. - std::function backward_function_deleter; + std::function backward_function_deleter; }; // Map from tensor_id to internally-defined operation-id of the operation which @@ -110,12 +110,6 @@ class VSpace { // Deletes the input tensor. virtual void DeleteGradient(Gradient* gradient) const = 0; - - // Lets this VSpace know that it can release resources held by the - // `backward_function`, It will not be called again. - // `backward_function` must not be null. - virtual void ReleaseBackwardFunction( - BackwardFunction* backward_function) const = 0; }; // Traces the execution of operations, doing eager garbage collection, and @@ -130,7 +124,7 @@ class GradientTape { GradientTape(bool persistent) : persistent_(persistent) {} ~GradientTape() { for (const auto& pair : op_tape_) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } @@ -139,12 +133,12 @@ class GradientTape { void Watch(int64 tensor_id); - void RecordOperation(const string& op_type, - gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, - gtl::ArraySlice input_dtypes, - BackwardFunction* backward_function, - const std::function& backward_function_deleter); + void RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + gtl::ArraySlice input_dtypes, + BackwardFunction* backward_function, + const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); @@ -218,9 +212,9 @@ void GradientTape::RecordOperation( gtl::ArraySlice input_tensor_id, gtl::ArraySlice input_dtypes, BackwardFunction* backward_function, - const std::function& backward_function_deleter) { + const std::function& backward_function_deleter) { if (!ShouldRecord(input_tensor_id, input_dtypes)) { - backward_function_deleter(); + backward_function_deleter(backward_function); return; } std::vector ids; @@ -275,7 +269,7 @@ void GradientTape::DeleteTrace(int64 tensor_id) { for (int64 id : op_it->second.input_tensor_id) { DeleteTrace(id); } - op_it->second.backward_function_deleter(); + op_it->second.backward_function_deleter(op_it->second.backward_function); op_tape_.erase(op_it); } @@ -381,7 +375,8 @@ BackpropInitialState PrepareBackprop( // backward functions that will be used for gradient computation // has been transferred to `result`. for (const auto& op_pair : *op_tape) { - op_pair.second.backward_function_deleter(); + op_pair.second.backward_function_deleter( + op_pair.second.backward_function); } op_tape->clear(); } @@ -473,7 +468,7 @@ Status GradientTape::ComputeGradient( if (!persistent_) { // Release all backprop functions for (const auto& pair : state.op_tape) { - pair.second.backward_function_deleter(); + pair.second.backward_function_deleter(pair.second.backward_function); } } }; @@ -541,7 +536,7 @@ Status GradientTape::ComputeGradient( Status s = vspace.CallBackwardFunction(trace.backward_function, out_gradients, &in_gradients); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } if (!s.ok()) { cleanup(); @@ -550,7 +545,7 @@ Status GradientTape::ComputeGradient( } else { in_gradients.resize(trace.input_tensor_id.size()); if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); + trace.backward_function_deleter(trace.backward_function); } for (Gradient* grad : out_gradients) { if (grad != nullptr) { diff --git a/tensorflow/c/generate-pc.sh b/tensorflow/c/generate-pc.sh index 02a6a58b6153bb78c684f9290ef95900f96e9357..7184ad68fb79f2598067d68d5ab5ba8f2c7a22c8 100755 --- a/tensorflow/c/generate-pc.sh +++ b/tensorflow/c/generate-pc.sh @@ -15,10 +15,12 @@ # ============================================================================== TF_PREFIX='/usr/local' +LIBDIR='lib' usage() { echo "Usage: $0 OPTIONS" echo -e "-p, --prefix\tset installation prefix (default: /usr/local)" + echo -e "-l, --libdir\tset lib directory (default: lib)" echo -e "-v, --version\tset TensorFlow version" echo -e "-h, --help\tdisplay this message" } @@ -26,7 +28,7 @@ usage() { [ $# == 0 ] && usage && exit 0 # read the options -ARGS=$(getopt -o p:v:h --long prefix:,version:,help -n $0 -- "$@") +ARGS=$(getopt -o p:l:v:h --long prefix:,libdir:,version:,help -n $0 -- "$@") eval set -- "$ARGS" # extract options and their arguments into variables. @@ -38,6 +40,11 @@ while true ; do "") shift 2 ;; *) TF_PREFIX=$2 ; shift 2 ;; esac ;; + -l|--libdir) + case "$2" in + "") shift 2 ;; + *) LIBDIR=$2 ; shift 2 ;; + esac ;; -v|--version) case "$2" in "") shift 2 ;; @@ -55,7 +62,7 @@ echo "Generating pkgconfig file for TensorFlow $TF_VERSION in $TF_PREFIX" cat << EOF > tensorflow.pc prefix=${TF_PREFIX} exec_prefix=\${prefix} -libdir=\${exec_prefix}/lib +libdir=\${exec_prefix}/${LIBDIR} includedir=\${prefix}/include Name: TensorFlow diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d6a4f141b6bb8ccadb77f1fa83b5fb742d78f70f..dfdef88945deca376368edd6f7aa322b1e1cbf94 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -273,6 +273,12 @@ string PrintAttrValue(const string& op, const AttrValue& attr_value) { return ""; // Prevent missing return warning } +bool IsEmptyList(const AttrValue::ListValue& list) { + return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && + list.b_size() == 0 && list.type_size() == 0 && + list.shape_size() == 0 && list.tensor_size() == 0; +} + string ToCamelCase(const string& str) { string result; const char joiner = '_'; @@ -297,9 +303,9 @@ string ToCamelCase(const string& str) { // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { - static const std::unordered_map, - StringPieceHasher> - attr_type_map{ + static const auto* attr_type_map = + new std::unordered_map, + StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, {"int", {"int64", false}}, @@ -317,14 +323,34 @@ std::pair AttrTypeName(StringPiece attr_type) { {"func", {"NameAttrList", true}}, }; - auto entry = attr_type_map.find(attr_type); - if (entry == attr_type_map.end()) { + auto entry = attr_type_map->find(attr_type); + if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } +const char* ListElementTypeName(StringPiece attr_type) { + static const auto* attr_list_type_map = + new std::unordered_map{ + {"list(string)", "string"}, + {"list(int)", "int"}, + {"list(float)", "float"}, + {"list(bool)", "bool"}, + {"list(type)", "DataType"}, + {"list(shape)", "PartialTensorShape"}, + {"list(tensor)", "TensorProto"}, + }; + + auto entry = attr_list_type_map->find(attr_type); + if (entry == attr_list_type_map->end()) { + LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; + return ""; + } + return entry->second; +} + bool IsCPPKeyword(StringPiece name) { static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword @@ -668,6 +694,7 @@ OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; + string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); @@ -705,11 +732,32 @@ string OpInfo::GetOpAttrStruct() const { "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); - strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), - "_ = ", - PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), - ";\n"); + string field_initiliazer; + auto& default_value = api_def_attr.default_value(); + if (default_value.value_case() == AttrValue::kList && + !IsEmptyList(default_value.list())) { + // Non-empty lists need static storage for their defaults. Define a + // function with static local variable that stores the array. + strings::StrAppend(&defaults_static_storage, " static ", + attr_type_name, " Default_", api_def_attr.rename_to(), + "() {\n"); + strings::StrAppend( + &defaults_static_storage, " static const ", + ListElementTypeName(attr.type()), " kStorage[] = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); + strings::StrAppend(&defaults_static_storage, " return ", + attr_type_name, "(kStorage);\n }\n"); + // Set the field_initializer to call the defined function. + strings::StrAppend(&field_initiliazer, "Default_", + api_def_attr.rename_to(), "()"); + } else { + field_initiliazer = + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); + } + strings::StrAppend(&struct_fields, " ", attr_type_name, " ", + api_def_attr.rename_to(), "_ = ", field_initiliazer, + ";\n"); } if (struct_fields.empty()) { @@ -721,6 +769,9 @@ string OpInfo::GetOpAttrStruct() const { string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); + if (!defaults_static_storage.empty()) { + strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); + } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 52c177212a8c88f1857defcc38de4a01ac47dab0..35a01e0341cb08c9b314908b6dcd76fd99c1e68b 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -38,6 +38,7 @@ REGISTER_NO_GRADIENT_OP("NotEqual"); REGISTER_NO_GRADIENT_OP("LogicalAnd"); REGISTER_NO_GRADIENT_OP("LogicalOr"); REGISTER_NO_GRADIENT_OP("LogicalNot"); +REGISTER_NO_GRADIENT_OP("Floor"); // Conjugate helper function returns the conjugate of an Output if it // is complex valued. diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index fd2cf2b67d4618dd626b8eef78eed044d7fde0a4..0ecc3feeb6fef1dd691ab2785b3221075a79ba88 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -7,6 +7,10 @@ package( load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") +# We disable some tfcompile tests in the open source build with the +# "manual" tag to avoid making our OSS users build LLVM twice +# (once for host and once for target). + test_suite( name = "all_tests", tags = ["manual"], diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 980e0eec9e23b15a97b826067bac08053a437712..8c74014614789758192691ee065f92759a113a7a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -25,6 +25,7 @@ load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( @@ -178,7 +179,9 @@ cc_library( "//tensorflow/core/kernels:identity_n_op", "//tensorflow/core/kernels:identity_op", "//tensorflow/core/kernels:no_op", + "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:shape_ops", "//tensorflow/core/kernels:variable_ops", ], ) @@ -311,9 +314,9 @@ cc_library( ":common", ":shape_inference_helpers", ":union_find", + ":xla_cluster_util", "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/jit/kernels:parallel_check_op", - "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags", "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", @@ -331,6 +334,19 @@ cc_library( ], ) +cc_library( + name = "xla_cluster_util", + srcs = ["xla_cluster_util.cc"], + hdrs = ["xla_cluster_util.h"], + deps = [ + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:bounds_check", + ], +) + cc_library( name = "union_find", hdrs = ["union_find.h"], @@ -407,6 +423,38 @@ tf_cc_test( ], ) +cc_library( + name = "xla_fusion_optimizer", + srcs = ["xla_fusion_optimizer.cc"], + hdrs = ["xla_fusion_optimizer.h"], + visibility = ["//visibility:public"], + deps = [ + ":common", + ":union_find", + ":xla_cluster_util", + "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ], +) + +tf_cuda_cc_test( + name = "xla_fusion_optimizer_test", + srcs = ["xla_fusion_optimizer_test.cc"], + deps = [ + ":common", + ":xla_cluster_util", + ":xla_fusion_optimizer", + "//tensorflow/core:graph", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) + # This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. cc_header_only_library( name = "xla_jit_headers_lib", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 6d1e3325ebd35b9608ea273fb7de39bad381e60d..ea90d714c8fe2c2959a92a83ec105dbf2c098f7a 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -23,7 +23,6 @@ limitations under the License. #include #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" @@ -182,8 +181,7 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, - FunctionLibraryDefinition* library); + Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -271,7 +269,7 @@ class Encapsulator { // Adds the function call node to graph_out. Status AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. Status AddOutsideCompilationHostIONodes( @@ -284,11 +282,9 @@ class Encapsulator { // Subgraph. void GetOutsideCompilationSubgraphNames(std::vector* names) const; - // Returns the Node that inputs to the function should be wired up to. - Node* GetCallNodeForInputs() const; - - // Returns the Node that outputs to the function should be wired up to. - Node* GetCallNodeForOutputs() const; + // Returns the Node that the inputs and outputs of the function should be + // wired up to. + Node* GetCallNode() const; // Returns the index of the arg that the dst of edge should connect to. int GetArgIndexForEdge(const Edge* edge) const; @@ -425,12 +421,6 @@ class Encapsulator { OutsideCompilationSubgraph* LookupOrCreateOutsideCompilationSubgraph( const string& outside_compilation_id); - // Builds a ParallelCheck op that compares the output of the original - // subgraph with the encapsulated subgraph. - Status BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out); - // Builds a placeholder node used to provide the key input to a RecvAtHost // or SendFromHost node. This placeholder node will be removed by a later // pass. @@ -482,13 +472,8 @@ class Encapsulator { // Not owned. Node* host_compute_key_placeholder_ = nullptr; - // Function call node(s) in the output graph. Not owned. - // If parallel_checking is enabled, 'call_node_inputs' is the function call - // node to which inputs should be fed, and 'call_node_outputs' is the - // parallel check op from which outputs should be read. If parallel checking - // is disabled, both point to the function call node. - Node* call_node_inputs_; - Node* call_node_outputs_; + // Function call node in the output graph. Not owned. + Node* call_node_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in @@ -541,13 +526,12 @@ class Encapsulator { // Copies all nodes that aren't in a compiled subgraph to the output graph. Status CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images); + Graph* graph_out, std::unordered_map* node_images); // Adds function call nodes for each compiled subgraph. Status AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all // outside_compilation subgraphs. @@ -598,7 +582,7 @@ class Encapsulator { const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, + Graph* graph_out, std::unordered_set, NodeSlot::PairHasher>* edges_added); @@ -609,7 +593,7 @@ class Encapsulator { // Adds all edges to the output graph. Status AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out); + Graph* graph_out); // Constructs a minimal shape inference graph that can be used to determine // the shape of send_node at the time that the subgraph is compiled. @@ -729,13 +713,7 @@ void TopologicalClusterSort( } // namespace -Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { - return call_node_inputs_; -} - -Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { - return call_node_outputs_; -} +Node* Encapsulator::Subgraph::GetCallNode() const { return call_node_; } int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); @@ -1075,7 +1053,7 @@ Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { if (sequencer_ != nullptr) { VLOG(2) << "ConnectSequencerToCallNode"; - graph_out->AddControlEdge(sequencer_, call_node_inputs_); + graph_out->AddControlEdge(sequencer_, call_node_); } } @@ -1200,83 +1178,16 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( return Status::OK(); } -Status Encapsulator::Subgraph::BuildParallelCheckOp( - const std::unordered_map& node_images, - Graph* graph_out) { - // Build an index mapping output positions to node/slot pairs in the - // original graph. - std::vector results_by_num(results_.size()); - for (const auto& entry : results_) { - results_by_num[entry.second] = entry.first; - } - - // Build a parallel check NodeDef. - int num_results = results_by_num.size(); - std::vector result_dtypes(num_results); - std::vector expected_outputs(num_results); - std::vector actual_outputs(num_results); - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - result_dtypes[i] = node_slot.node->output_type(node_slot.slot); - expected_outputs[i] = - NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), - node_slot.slot, result_dtypes[i]); - actual_outputs[i] = - NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); - } - // Assign the parallel check op to a CPU on the same task as the cluster it is - // checking. - string device, dummy; - if (!DeviceNameUtils::SplitDeviceName( - call_node_inputs_->assigned_device_name(), &device, &dummy)) { - return errors::InvalidArgument("Could not parse device name"); - } - strings::StrAppend(&device, "/cpu:0"); - - NodeDef check_def; - TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), - "_parallel_check")), - "ParallelCheck") - .Device(device) - .Attr("T", result_dtypes) - .Input(expected_outputs) - .Input(actual_outputs) - .Finalize(&check_def)); - - Status s; - Node* check_op = graph_out->AddNode(check_def, &s); - if (!s.ok()) return s; - check_op->set_assigned_device_name(device); - - // TODO(phawkins): it seems redundant to call AddEdge as well as - // pass Inputs to the NodeDefBuilder, but I have been unable to find a - // way to avoid it. - for (int i = 0; i < num_results; ++i) { - const NodeSlot& node_slot = results_by_num[i]; - graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, - i); - graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); - } - - call_node_outputs_ = check_op; - return Status::OK(); -} - Status Encapsulator::Subgraph::AddFunctionCallNode( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { Status s; - call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); + call_node_ = graph_out->AddNode(call_node_def_, &s); if (!s.ok()) return s; // Copy the assigned device and the key_annotation over. - call_node_inputs_->set_assigned_device_name(device_); - call_node_outputs_ = call_node_inputs_; + call_node_->set_assigned_device_name(device_); - if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); - } return Status::OK(); } @@ -1627,27 +1538,17 @@ Status Encapsulator::BuildFunctionDefs( } Status Encapsulator::CopyNodesToOutputGraph( - bool parallel_checking, Graph* graph_out, - std::unordered_map* node_images) { + Graph* graph_out, std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; string outside_compilation_id; TF_RETURN_IF_ERROR( GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); - // Don't copy nodes that going to be encapsulated, unless parallel checking - // is enabled. - if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) - continue; + // Don't copy nodes that are going to be encapsulated. + if (IsInSubgraph(func_id, outside_compilation_id)) continue; Node* image = graph_out->CopyNode(node); - if (!outside_compilation_id.empty()) { - if (parallel_checking) { - return errors::InvalidArgument( - "Parallel checking is not supported when outside_compilation " - "clusters are present."); - } - } (*node_images)[node] = image; } (*node_images)[graph_in_->source_node()] = graph_out->source_node(); @@ -1657,10 +1558,10 @@ Status Encapsulator::CopyNodesToOutputGraph( Status Encapsulator::AddFunctionCallNodes( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { - TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( - node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR( + subgraph_entry.second.AddFunctionCallNode(node_images, graph_out)); } return Status::OK(); } @@ -1694,7 +1595,7 @@ Status Encapsulator::FindOutputImageOfEdgeSrc( } else { // The edge is from a subgraph to a regular node in the output graph so // use the subgraph's call node output. - *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + *src_image = subgraphs_.at(src_func_id).GetCallNode(); } } else { // The source of the edge is in the output graph so use the node image in @@ -1742,7 +1643,7 @@ Status Encapsulator::FindOutputImageOfEdgeDst( } else { // The edge is to a subgraph from a regular node in the output graph so // use the subgraph's call node input. - *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + *dst_image = subgraphs_.at(dst_func_id).GetCallNode(); } } else { // The destination of the edge is in the output graph so use the node image @@ -1778,8 +1679,7 @@ Status Encapsulator::CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& src_outside_compilation_id, const string& dst_func_id, const string& dst_outside_compilation_id, - const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out, + const std::unordered_map& node_images, Graph* graph_out, std::unordered_set, NodeSlot::PairHasher>* edges_added) { Node* src_image; @@ -1801,11 +1701,6 @@ Status Encapsulator::CopyEdgeToOutputGraph( graph_out->AddControlEdge(src_image, dst_image); } - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); - } return Status::OK(); } @@ -1817,14 +1712,6 @@ Status Encapsulator::CopyEdgeToOutputGraph( FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, dst_func_id, dst_outside_compilation_id, edge); - if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && - parallel_checking) { - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. - graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); - } - // Add the edge, if we have not already added it. if (edges_added ->emplace(NodeSlot(src_image, src_output), @@ -1839,8 +1726,8 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { for (const auto& ancestors : subgraph_ancestors_) { const string& subgraph = ancestors.first; for (const string& ancestor : ancestors.second) { - graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNodeForOutputs(), - subgraphs_[subgraph].GetCallNodeForInputs()); + graph_out->AddControlEdge(subgraphs_[ancestor].GetCallNode(), + subgraphs_[subgraph].GetCallNode()); } } return Status::OK(); @@ -1848,7 +1735,7 @@ Status Encapsulator::AddCallNodeDependencies(Graph* graph_out) { Status Encapsulator::AddEdgesToOutputGraph( const std::unordered_map& node_images, - bool parallel_checking, Graph* graph_out) { + Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. @@ -1870,16 +1757,6 @@ Status Encapsulator::AddEdgesToOutputGraph( if (IsInSubgraph(src_func_id, src_outside_compilation_id) && IsInSubgraph(dst_func_id, dst_outside_compilation_id) && src_func_id == dst_func_id) { - if (parallel_checking) { - Node* src_image = node_images.at(edge->src()); - Node* dst_image = node_images.at(edge->dst()); - if (edge->IsControlEdge()) { - graph_out->AddControlEdge(src_image, dst_image); - } else { - graph_out->AddEdge(src_image, edge->src_output(), dst_image, - edge->dst_input()); - } - } continue; } @@ -1887,8 +1764,7 @@ Status Encapsulator::AddEdgesToOutputGraph( // unclustered graph. TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( edge, src_func_id, src_outside_compilation_id, dst_func_id, - dst_outside_compilation_id, node_images, parallel_checking, graph_out, - &edges_added)); + dst_outside_compilation_id, node_images, graph_out, &edges_added)); } for (auto& subgraph_entry : subgraphs_) { @@ -2504,18 +2380,15 @@ Status Encapsulator::GetShapeInfoForOutsideCompilationSends( return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, +Status Encapsulator::BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. std::unordered_map node_images; - TF_RETURN_IF_ERROR( - CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); - TF_RETURN_IF_ERROR( - AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(CopyNodesToOutputGraph(graph_out, &node_images)); + TF_RETURN_IF_ERROR(AddFunctionCallNodes(node_images, graph_out)); TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); - TF_RETURN_IF_ERROR( - AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddEdgesToOutputGraph(node_images, graph_out)); TF_RETURN_IF_ERROR( GetShapeInfoForOutsideCompilationSends(graph_out, library)); @@ -2528,8 +2401,8 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, Status EncapsulateSubgraphsInFunctions( string group_attribute, string outside_compilation_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, - bool parallel_checking, bool reuse_existing_functions, - std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library) { Status s; Encapsulator encapsulator(std::move(group_attribute), @@ -2543,8 +2416,7 @@ Status EncapsulateSubgraphsInFunctions( std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - TF_RETURN_IF_ERROR( - encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); + TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(out.get(), library)); *graph_out = std::move(out); return Status::OK(); @@ -2585,8 +2457,6 @@ static Status RenumberArguments(Graph* graph, Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; - legacy_flags::EncapsulateSubgraphsPassFlags* flags = - legacy_flags::GetEncapsulateSubgraphsPassFlags(); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph, options.flib_def); @@ -2663,7 +2533,7 @@ Status EncapsulateSubgraphsPass::Run( TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, - rewrite_subgraph, flags->tf_xla_parallel_checking, + rewrite_subgraph, /*reuse_existing_functions=*/false, &graph_out, library)); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 5fee36f022a7515504cb6faa5cca658481b784c5..e5dab7c657c79afa861b0443314d0c7801e4b66d 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -61,10 +61,6 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); + bool reuse_existing_functions, std::unique_ptr* graph_out, + FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via XlaLaunch operators. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 5ec24d39a2c40a766dbb0ec51ebe798de620e24b..6a7cd932e53cc0428850ab048cc325e84ef1fce6 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -511,7 +511,6 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::unique_ptr graph_out; s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_out, lib_def.get()); if (!s.ok()) return s; @@ -560,8 +559,9 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { Node* b = Input(b1.opts().WithName("B")); // Give nodes 'c' and 'd' names that collide after lowercasing. Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); - Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr( - "_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); } @@ -614,8 +614,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctions) { Node* c = Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr( "_encapsulate", "F1")); - Node* d = - Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr( + Node* d = Binary(b, c, + b1.opts().WithName("D").WithControlInput(control).WithAttr( "_encapsulate", "F2")); Binary(a, d, b1.opts().WithName("E")); TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); @@ -707,7 +707,7 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*rewrite_subgraph_fn=*/{}, /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; @@ -721,47 +721,6 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } -TEST(EncapsulateSubgraphsTest, ParallelChecking) { - Scope root = Scope::NewRootScope().ExitOnError().WithDevice( - "/job:localhost/replica:0/task:0/cpu:0"); - auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); - auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); - auto add1 = ops::Add(root.WithOpName("add1"), x1, x2); - add1.node()->AddAttr("_cluster", "cluster1"); - auto add2 = ops::Add(root.WithOpName("add2"), add1, x2); - add2.node()->AddAttr("_cluster", "cluster1"); - auto out = ops::Mul(root.WithOpName("mul"), x1, add2); - - Graph graph_before_encapsulation(OpRegistry::Global()); - TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation)); - - FunctionLibraryDefinition library(OpRegistry::Global(), {}); - std::unique_ptr graph; - TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", "_outside", graph_before_encapsulation, - /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, - /*reuse_existing_functions=*/false, &graph, &library)); - - std::vector expected_nodes = { - "add1", "add2", "cluster1", "cluster1_parallel_check/_0", - "mul", "x1", "x2"}; - EXPECT_EQ(expected_nodes, GraphNodes(*graph)); - - std::vector> expected_edges = { - {"add1:0", "add2:0"}, - {"add2:0", "cluster1_parallel_check/_0:0"}, - {"cluster1:0", "cluster1_parallel_check/_0:1"}, - {"cluster1_parallel_check/_0:0", "mul:1"}, - {"x1:0", "add1:0"}, - {"x1:0", "cluster1:0"}, - {"x1:0", "mul:0"}, - {"x2:0", "add1:1"}, - {"x2:0", "add2:1"}, - {"x2:0", "cluster1:1"}, - }; - EXPECT_EQ(expected_edges, GraphEdges(*graph)); -} - const Node* FindNodeByName(const Graph& graph, const string& name) { for (const Node* node : graph.nodes()) { if (node->name() == name) return node; @@ -814,7 +773,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); EXPECT_EQ(2, guaranteed_consts); } @@ -859,7 +817,6 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { } return Status::OK(); }, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph_after, &library)); // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const // and another non-const, so overall non-const. @@ -1050,7 +1007,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { .WithAttr("_outside", "O1")); Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, shape2.opts()); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, shape2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") @@ -1075,7 +1032,7 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { {"outside_compilation_O1_host_compute"}}, {{"outside_compilation_O2_host_compute"}, "XlaHostCompute", - {"D:o:0", "F:o:0"}, + {"F:o:0", "D:o:0"}, {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, {"ancestors", @@ -1123,13 +1080,13 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { Node* recv2 = RecvAtHost(ops::NodeOut(key_constant, 0), "F1", "O2", {DT_FLOAT, DT_FLOAT}, b2.opts()); - Node* g = Binary(e, ops::NodeOut(recv2, 1), + Node* g = Binary(e, ops::NodeOut(recv2, 0), b2.opts() .WithName("G") .WithControlInputs({recv2, e}) .WithAttr("_encapsulate", "F1") .WithAttr("_outside", "O2")); - Node* h = Binary(ops::NodeOut(recv2, 0), e, + Node* h = Binary(ops::NodeOut(recv2, 1), e, b2.opts() .WithName("H") .WithAttr("_encapsulate", "F1") diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 27287e0f9637929b2e04c6a76de19c2785ec357e..902fe27acdec1cb323217e6409fbd02f62177612 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -148,7 +148,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { XlaCompiler::Options options; options.client = client; - options.device_type = &cache->device_type(); + options.device_type = cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD index 5d211f4d733d8d807426e62dd116092799184f35..5b6692f523658749f7ef48f9d7d89e97d4ce8b09 100644 --- a/tensorflow/compiler/jit/legacy_flags/BUILD +++ b/tensorflow/compiler/jit/legacy_flags/BUILD @@ -16,18 +16,6 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -cc_library( - name = "encapsulate_subgraphs_pass_flags", - srcs = ["encapsulate_subgraphs_pass_flags.cc"], - hdrs = ["encapsulate_subgraphs_pass_flags.h"], - deps = - [ - "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "mark_for_compilation_pass_flags", srcs = ["mark_for_compilation_pass_flags.cc"], diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc deleted file mode 100644 index 856475f12c8a411cd80c1c1859323304ca4029e0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include -#include - -#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static EncapsulateSubgraphsPassFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new EncapsulateSubgraphsPassFlags; - flags->tf_xla_parallel_checking = false; - flag_list = new std::vector({ - Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking, - "Debug tool. Runs both JIT-compiled and interpreted graphs in " - "parallel and verifies they produce the same outputs."), - }); - xla::legacy_flags::ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace tensorflow diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h deleted file mode 100644 index d371bd269dbdfbf737d81490fb877fcf88661a8f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ -#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ - -// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace tensorflow { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -void AppendEncapsulateSubgraphsPassFlags( - std::vector* flag_list); - -// The values of flags associated with the XLA bridge's -// encapsulate_subgraphs_pass module. -typedef struct { - bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and - // interpreted graphs in parallel and verifies - // they produce the same outputs. -} EncapsulateSubgraphsPassFlags; - -// Return a pointer to the EncapsulateSubgraphsPassFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags(); - -} // namespace legacy_flags -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_ diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8e2ee0f1d71bc17b4c12c792c38002af4f9eb5eb..8c3882116dd4f048ea3e32c037bf4139c67a3eb9 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graphcycles/graphcycles.h" #include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/common_runtime/function.h" @@ -41,9 +42,6 @@ limitations under the License. namespace tensorflow { -const char* const kXlaClusterAttr = "_XlaCluster"; -const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; - namespace { bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { @@ -60,6 +58,14 @@ bool HasXLAKernel(const Node& node, const DeviceType& jit_device_type) { return false; } } + + // XLA does not offer guaranteed aliasing between the input and output of the + // XLA cluster so it can't implement the forward-tensor-ref semantic. Leave + // such nodes out of XLA clusters. + if (HasForwardedRefInput(node)) { + return false; + } + return FindKernelDef(jit_device_type, node.def(), nullptr, nullptr).ok(); } @@ -165,16 +171,6 @@ bool IsCompilableCall(const NodeDef& call_def, return true; } -// Returns the DeviceType corresponding to 'device'. -Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) { - DeviceNameUtils::ParsedName parsed; - if (!DeviceNameUtils::ParseFullName(device, &parsed)) { - return errors::Internal("Malformed assigned device '", device, "'"); - } - *device_type = DeviceType(parsed.type); - return Status::OK(); -} - // Tests whether `node` has a DT_RESOURCE typed input or output. bool HasResourceInputOrOutput(const Node& node) { return std::find(node.input_types().begin(), node.input_types().end(), @@ -183,18 +179,11 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } -struct NodeCompare { - bool operator()(const Node* a, const Node* b) const { - return a->id() < b->id(); - } -}; -using OrderedNodeSet = std::set; - // Returns true if the op can be decomposed into XLA ops for which // there are fusable elemental implementations. // -// TODO(hpucha): Consider a black list instead of a white list as -// implemented below. +// TODO(hpucha): Remove this code since this functionality is subsumed by +// Grappler XlaFusionOptimizer. bool IsXlaFusable(const NodeDef& node) { static const std::unordered_set* elementwise_ops = new std::unordered_set( @@ -364,7 +353,7 @@ Status FindCompilationCandidates( for (Node* node : graph.op_nodes()) { sorted_nodes.push_back(node); } - std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeCompare()); + std::sort(sorted_nodes.begin(), sorted_nodes.end(), NodeComparatorID()); for (Node* node : sorted_nodes) { VLOG(2) << "Fuel: " << fuel; @@ -379,9 +368,13 @@ Status FindCompilationCandidates( DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); + DeviceToDeviceType(node->assigned_device_name(), &device_type)); - if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue; + if (is_compilable_fn && !is_compilable_fn(node, device_type)) { + VLOG(2) << "Compilation rejected node: not compilable " << node->name() + << ": " << node->type_string(); + continue; + } const XlaOpRegistry::DeviceRegistration* registration; CHECK( @@ -430,46 +423,6 @@ struct Cluster { int representative = -1; }; -// Returns a string describing how an edge from src to dst would -// create a cycle. -string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, - int dst) { - int32 max_path_size = graph.num_node_ids() + 1; - std::vector path(max_path_size); - int32 path_size = cycles.FindPath(dst, src, max_path_size, path.data()); - if (path_size == 0) { - return ""; - } - - auto node_name = [&cycles, &graph](int node_id) { - if (!FastBoundsCheck(node_id, graph.num_node_ids())) { - return string("(null)"); - } - auto* node = graph.FindNodeId(node_id); - if (node == nullptr) { - return string("(null)"); - } - return node->name(); - }; - - string description; - strings::StrAppend(&description, "Edge from ", node_name(src), " to ", - node_name(dst), " would create a cycle.\n"); - path.resize(path_size); - for (int32 node_id : path) { - string ascii_art; - if (node_id == dst) { - ascii_art = "+-> "; - } else if (node_id != src) { - ascii_art = "| "; - } else { - ascii_art = "+-- "; - } - strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); - } - return description; -} - } // anonymous namespace bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) { @@ -575,84 +528,13 @@ Status MarkForCompilationPass::RunImpl( : Env::Default(), is_compilable_fn, &compilation_candidates)); - GraphCycles cycles; - for (int i = 0; i < graph->num_node_ids(); ++i) { - // We rely on the node IDs in the cycle detection graph being consecutive - // integers starting from 0. - CHECK_EQ(i, cycles.NewNode()); + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + return Status::OK(); } - // Compute the loop structure of the graph. - std::vector control_flow_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); - - // The clustering code must avoid adding cycles to the graph to prevent - // deadlock. However, the graph may contain loops, which would trigger the - // cycle detection code. To handle loops, we alter the structure of the cycle - // detection graph, disconnecting each loop from the enclosing graph. - // Specifically, we: - // * add a new "frame" node for each loop. - // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges - // to/from the corresponding frame node. In essence, we collapse the loop - // into a single node for the purpose of cycle detection in the enclosing - // graph. - // * the body of the loop should now be disconnected from the rest of the - // graph; we make it acyclic by breaking loop backedges (edges outgoing from - // "NextIteration" nodes. - - // Map from frame name strings to node IDs in the cycle detection graph. - std::unordered_map frame_nodes; - - // Get the cycle graph node ID for frame 'frame_name', or add one if none - // exists. - auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) { - int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; - if (frame_id < 0) { - // The emplace succeeded; we have not allocated a frame node yet. - frame_id = cycles.NewNode(); - } - return frame_id; - }; - - for (Edge const* edge : graph->edges()) { - if (edge->dst()->IsEnter()) { - // Lift edges to an "Enter" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->dst()->id()].frame_name; - int dst = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(edge->src()->id(), dst)) { - return errors::Internal( - "Cycle detected when adding enter->frame edge: ", - DescribeCycle(cycles, *graph, edge->src()->id(), dst)); - } - continue; - } - if (edge->src()->IsExit()) { - // Lift edges from an "Exit" node to the corresponding frame node. - const string& frame_name = - control_flow_info[edge->src()->id()].frame_name; - int src = GetOrAddFrameNodeId(frame_name); - if (!cycles.InsertEdge(src, edge->dst()->id())) { - return errors::Internal( - "Cycle detected when adding frame->exit edge: ", - DescribeCycle(cycles, *graph, src, edge->dst()->id())); - } - // Drop the original edge. - continue; - } - if (edge->src()->IsNextIteration()) { - // Break loop back-edges. - continue; - } - if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) { - // This should never happen. All cycles in the graph should contain - // a control flow operator. - return errors::Internal( - "Found cycle in graph without control flow operator during XLA " - "compilation: ", - DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); - } - } + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(graph, &cycles)); // Each compilation candidate belongs to a cluster. The cluster's // representative @@ -670,6 +552,9 @@ Status MarkForCompilationPass::RunImpl( // Repeatedly contract edges between clusters that are on the same device, // provided the contraction would not create a cycle. + // + // TODO(hpucha): Handle the case where kXlaClusterAttr is already set (for + // example, from the Grappler fusion pass). while (!worklist.empty()) { int from = worklist.front()->Get().representative; worklist.pop_front(); @@ -778,7 +663,7 @@ Status MarkForCompilationPass::RunImpl( // compilation. DeviceType device_type(""); TF_RETURN_IF_ERROR( - DeviceTypeOfDevice(n->assigned_device_name(), &device_type)); + DeviceToDeviceType(n->assigned_device_name(), &device_type)); const XlaOpRegistry::DeviceRegistration* registration; XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 703d8825d74ced8d4d69c31ccd730adc89a8bffe..772c92d369e67f431b5d030d1d5cdc5ae2700d39 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -633,5 +633,52 @@ TEST(XlaCompilationTest, ConstOp) { } } +TEST(XlaCompilationTest, DontClusterIdentityWithRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output add = ops::Add(root.WithOpName("add"), neg, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + +TEST(XlaCompilationTest, ClusterIdentityWithNonRefInput) { + Scope root = Scope::NewRootScope().ExitOnError(); + Output variable = ops::Variable(root.WithOpName("variable"), + PartialTensorShape{}, DT_FLOAT); + Output read = ops::Identity(root.WithOpName("read"), variable); + Output neg = ops::Negate(root.WithOpName("negate"), read); + Output identity = ops::Negate(root.WithOpName("identity"), neg); + Output add = ops::Add(root.WithOpName("add"), identity, neg); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + TF_ASSERT_OK(root.ToGraph(graph.get())); + TF_ASSERT_OK(MarkForCompilation(&graph)); + + std::unordered_map clusters = GetClusters(*graph); + + ASSERT_FALSE(clusters.empty()); + string cluster_name = clusters.begin()->second; + + std::unordered_map expected_clusters( + {{"negate", cluster_name}, + {"identity", cluster_name}, + {"add", cluster_name}}); + EXPECT_EQ(clusters, expected_clusters); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..05b7821b8865d0f210ca9af92370e177d6043e80 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_cluster_util.h" + +#include + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; + +namespace { +// Returns a string describing how an edge from src to dst would +// create a cycle. +string DescribeCycle(const GraphCycles* cycles, const Graph& graph, int src, + int dst) { + int32 max_path_size = graph.num_node_ids() + 1; + std::vector path(max_path_size); + int32 path_size = cycles->FindPath(dst, src, max_path_size, path.data()); + if (path_size == 0) { + return ""; + } + + auto node_name = [cycles, &graph](int node_id) { + if (!FastBoundsCheck(node_id, graph.num_node_ids())) { + return string("(null)"); + } + auto* node = graph.FindNodeId(node_id); + if (node == nullptr) { + return string("(null)"); + } + return node->name(); + }; + + string description; + strings::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); + path.resize(path_size); + for (int32 node_id : path) { + string ascii_art; + if (node_id == dst) { + ascii_art = "+-> "; + } else if (node_id != src) { + ascii_art = "| "; + } else { + ascii_art = "+-- "; + } + strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); + } + return description; +} + +bool AlwaysForwardsRefInput(const Node& node) { return node.IsIdentity(); } + +} // namespace + +Status DeviceToDeviceType(const string& device, DeviceType* device_type) { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(device, &parsed)) { + return errors::Internal("Malformed assigned device '", device, "'"); + } + *device_type = DeviceType(parsed.type); + return Status::OK(); +} + +bool HasForwardedRefInput(const Node& node) { + if (AlwaysForwardsRefInput(node)) { + for (const Edge* incoming_edge : node.in_edges()) { + if (incoming_edge->IsControlEdge()) { + continue; + } + + Node* incoming_node = incoming_edge->src(); + if (IsRefType(incoming_node->output_type(incoming_edge->src_output()))) { + VLOG(2) << "Node " << node.def().ShortDebugString() << " has ref input " + << incoming_node->name() << " " << incoming_node->type_string(); + return true; + } + } + } + return false; +} + +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles) { + for (int i = 0; i < graph->num_node_ids(); ++i) { + // We rely on the node IDs in the cycle detection graph being consecutive + // integers starting from 0. + CHECK_EQ(i, cycles->NewNode()); + } + + // Compute the loop structure of the graph. + std::vector control_flow_info; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info)); + + // The clustering code must avoid adding cycles to the graph to prevent + // deadlock. However, the graph may contain loops, which would trigger the + // cycle detection code. To handle loops, we alter the structure of the cycle + // detection graph, disconnecting each loop from the enclosing graph. + // Specifically, we: + // * add a new "frame" node for each loop. + // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges + // to/from the corresponding frame node. In essence, we collapse the loop + // into a single node for the purpose of cycle detection in the enclosing + // graph. + // * the body of the loop should now be disconnected from the rest of the + // graph; we make it acyclic by breaking loop backedges (edges outgoing from + // "NextIteration" nodes. + + // Map from frame name strings to node IDs in the cycle detection graph. + std::unordered_map frame_nodes; + + // Get the cycle graph node ID for frame 'frame_name', or add one if none + // exists. + auto GetOrAddFrameNodeId = [&frame_nodes, cycles](const string& frame_name) { + int& frame_id = frame_nodes.emplace(frame_name, -1).first->second; + if (frame_id < 0) { + // The emplace succeeded; we have not allocated a frame node yet. + frame_id = cycles->NewNode(); + } + return frame_id; + }; + + for (Edge const* edge : graph->edges()) { + if (edge->dst()->IsEnter()) { + // Lift edges to an "Enter" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->dst()->id()].frame_name; + int dst = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(edge->src()->id(), dst)) { + return errors::Internal( + "Cycle detected when adding enter->frame edge: ", + DescribeCycle(cycles, *graph, edge->src()->id(), dst)); + } + continue; + } + if (edge->src()->IsExit()) { + // Lift edges from an "Exit" node to the corresponding frame node. + const string& frame_name = + control_flow_info[edge->src()->id()].frame_name; + int src = GetOrAddFrameNodeId(frame_name); + if (!cycles->InsertEdge(src, edge->dst()->id())) { + return errors::Internal( + "Cycle detected when adding frame->exit edge: ", + DescribeCycle(cycles, *graph, src, edge->dst()->id())); + } + // Drop the original edge. + continue; + } + if (edge->src()->IsNextIteration()) { + // Break loop back-edges. + continue; + } + if (!cycles->InsertEdge(edge->src()->id(), edge->dst()->id())) { + // This should never happen. All cycles in the graph should contain + // a control flow operator. + return errors::Internal( + "Found cycle in graph without control flow operator during XLA " + "compilation: ", + DescribeCycle(cycles, *graph, edge->src()->id(), edge->dst()->id())); + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_cluster_util.h b/tensorflow/compiler/jit/xla_cluster_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bcce082aaf6044ff0654efa4d78c0f493a350d00 --- /dev/null +++ b/tensorflow/compiler/jit/xla_cluster_util.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Contains utilities for clustering compilable graph nodes via XLA. + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ + +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/core/graph/algorithm.h" + +namespace tensorflow { + +// The attribute that marks nodes to be grouped into functions by the +// encapsulate subgraphs pass. +extern const char* const kXlaClusterAttr; + +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + +using OrderedNodeSet = std::set; + +// Returns the DeviceType corresponding to 'device'. +Status DeviceToDeviceType(const string& device, DeviceType* device_type); + +// Returns true if `node` has a ref tensor input that it forwards to its output. +bool HasForwardedRefInput(const Node& node); + +// Creates a graph representation to enable cycle detection when clustering. +// This representation handles loops in graph by disconnecting each loop from +// the enclosing graph. +Status CreateCycleDetectionGraph(const Graph* graph, GraphCycles* cycles); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_CLUSTER_UTIL_H_ diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index ab644ff5a61c407b246b97af5328bf5cd8c1893b..b1943d3e1a7e321b5a3796a0c6e4f2b5d9ac7018 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -151,8 +151,7 @@ Status XlaCompileOnDemandOp::Compile( core::ScopedUnref cache_ref(cache); XlaCompiler::Options options; - DeviceType device_type = metadata.jit_device_type(); - options.device_type = &device_type; + options.device_type = metadata.jit_device_type(); options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index ea9e0366043a4a64bfe43703c55d4470693bbac8..43648402f65c656b6b4eb2e83e61ce45f1c73669 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -50,11 +50,12 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR( - XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, - name_prefix, registration, - /*transfer_as_literal=*/false, - /*shape_representation_fn=*/{}, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index f13b46c532e6008477849f2e06887901c90038ab..ed007d603ea1b3d27dd25f00726261cdd029c20c 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/jit/xla_device_ops.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" @@ -105,6 +106,25 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( return alloc_ptr; } +namespace { + +// Default PaddedShapeFn implementation that simply returns the unpadded +// on-device shape. This is accurate for CPU and GPU devices that neither +// transpose nor pad tensors. +Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { + const tensorflow::XlaTensor* xla_tensor = + tensorflow::XlaTensor::FromTensor(&tensor); + if (xla_tensor == nullptr) { + return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape); + } + + const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer(); + *shape = shaped_buffer.on_device_shape(); + return Status::OK(); +} + +} // namespace + /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, @@ -112,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( const XlaOpRegistry::DeviceRegistration& registration, bool transfer_as_literal, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - std::unique_ptr* device) { + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -133,17 +153,20 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( device->reset(new XlaDevice( options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal, shape_representation_fn)); + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn, + padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn)); return Status::OK(); } XlaDevice::Metadata::Metadata( int device_ordinal, se::Platform* platform, const DeviceType& device_type, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn) : device_ordinal_(device_ordinal), device_type_(device_type), platform_(platform), - shape_representation_fn_(std::move(shape_representation_fn)) {} + shape_representation_fn_(std::move(shape_representation_fn)), + padded_shape_fn_(std::move(padded_shape_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -178,10 +201,11 @@ XlaDevice::XlaDevice( const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, se::Platform* platform, bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn) + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn) : LocalDevice(options, attrs), xla_metadata_(device_ordinal, platform, jit_device_name, - shape_representation_fn), + shape_representation_fn, padded_shape_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d5d345d43b16c43c7a202791b2604b39d29e8cdb..02e88ee6793e984a7b782790f8011cbcbc5a5026 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -45,13 +45,19 @@ namespace tensorflow { class XlaDevice : public LocalDevice { public: + // Given a tensor, sets `xla::Shape*` the shape of tensor's representation + // on device, fully padded. On error, the contents of `xla::Shape*` + // are undefined. + typedef std::function PaddedShapeFn; + // Wrapper class to store metadata about the XlaDevice, where it can be // retrieved e.g., when lazily creating the XlaCompilationCache device. class Metadata { public: Metadata(int device_ordinal, se::Platform* platform, const DeviceType& device_type, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + PaddedShapeFn padded_shape_fn); // The index of the device on this host. int device_ordinal() const; @@ -62,12 +68,14 @@ class XlaDevice : public LocalDevice { const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { return shape_representation_fn_; } + const PaddedShapeFn& padded_shape_fn() const { return padded_shape_fn_; } private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + PaddedShapeFn padded_shape_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -81,6 +89,8 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. + // If padded_shape_fn is empty, a default implementation that returns + // the on-host shape is used. static Status Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, @@ -88,12 +98,16 @@ class XlaDevice : public LocalDevice { const XlaOpRegistry::DeviceRegistration& registration, bool transfer_as_literal, const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - std::unique_ptr* device); + const PaddedShapeFn& padded_shape_fn, std::unique_ptr* device); + // Creates a new XLA Device. + // If padded_shape_fn is empty, a default implementation that returns + // the logical on-device shape without padding is used. XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, se::Platform* platform, bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn); + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + const PaddedShapeFn& padded_shape_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -110,6 +124,7 @@ class XlaDevice : public LocalDevice { Tensor* tensor) override; xla::LocalClient* client() const; + const Metadata& metadata() { return xla_metadata_; } xla::StatusOr GetStream(); // If not already set, create and set GpuDeviceInfo. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ff30b62bad782f281bcd25275521ed8b0c4c0bfd..71e63b110b3b132a57fc291e53a165954c72a03c 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -54,16 +54,26 @@ XlaTransferManager::XlaTransferManager( client_(client), transfer_manager_(client->backend().transfer_manager()), transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) {} + shape_representation_fn_(std::move(shape_representation_fn)) { + if (!shape_representation_fn_) { + shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) { + return shape; + }; + } +} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { - xla::Literal literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(host_tensor, &literal)); - VLOG(1) << "Transfer to device as literal: " << literal.ToString(); + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + xla::BorrowingLiteral literal( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); const xla::ShapedBuffer& shaped_buffer = XlaTensor::FromTensor(device_tensor)->shaped_buffer(); + VLOG(1) << "Transfer to device as literal: " << literal.ToString() << " " + << shaped_buffer.ToString(); return transfer_manager_->TransferLiteralToDevice(stream_->parent(), literal, shaped_buffer); } @@ -76,7 +86,8 @@ Status XlaTransferManager::TransferLiteralFromDevice( TF_ASSIGN_OR_RETURN(std::unique_ptr literal, transfer_manager_->TransferLiteralFromDevice( stream_->parent(), shaped_buffer)); - VLOG(1) << "Transfer from device as literal: " << literal->ToString(); + VLOG(1) << "Transfer from device as literal: " << literal->ToString() << " " + << shaped_buffer.ToString(); Tensor tensor; TF_RETURN_IF_ERROR( LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); @@ -98,7 +109,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, << " " << reinterpret_cast( device_tensor->tensor_data().data()) - << " " << cpu_tensor->NumElements(); + << " " << cpu_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); void* src_ptr = const_cast(DMAHelper::base(cpu_tensor)); const int64 total_bytes = cpu_tensor->TotalBytes(); @@ -106,13 +119,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); - TensorShape shape; - if (shape_representation_fn_) { - shape = shape_representation_fn_(device_tensor->shape(), - device_tensor->dtype()); - } else { - shape = device_tensor->shape(); - } + TensorShape shape = shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype()); if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( device_tensor->dtype(), shape, client_, @@ -165,7 +173,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_tensor->tensor_data().data()) << " " << reinterpret_cast(cpu_tensor->tensor_data().data()) - << device_tensor->NumElements(); + << " " << device_tensor->NumElements() << " " + << cpu_tensor->shape().DebugString() << " " + << device_tensor->shape().DebugString(); const int64 total_bytes = cpu_tensor->TotalBytes(); se::DeviceMemoryBase dev_src_ptr = @@ -194,6 +204,42 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } +void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + // TODO(phawkins): replace this code with an asynchronous implementation. + auto body = [&]() { + if (src_tensor.NumElements() == 0) { + return Status::OK(); + } + XlaTensor* xla_src = XlaTensor::FromTensor(&src_tensor); + XlaTensor* xla_dst = XlaTensor::FromTensor(dst_tensor); + CHECK(xla_src && xla_dst) + << "Missing destination tensor for device-to-device copy"; + if (!xla_dst->has_shaped_buffer()) { + TensorShape shape = + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()); + TF_RETURN_IF_ERROR( + xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, + stream_->parent()->device_ordinal())); + } + TF_RETURN_IF_ERROR( + xla_dst->shaped_buffer().buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const se::DeviceMemoryBase& from_buffer = + xla_src->shaped_buffer().buffers().element(index); + CHECK_EQ(buffer->size(), from_buffer.size()); + if (!stream_->parent()->SynchronousMemcpy(buffer, from_buffer, + buffer->size())) { + return errors::Internal("Device to device memcpy failed"); + } + return Status::OK(); + })); + return Status::OK(); + }; + done(body()); +} + XlaDeviceContext::XlaDeviceContext( se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, XlaCompiler::ShapeRepresentationFn shape_representation_fn) @@ -215,4 +261,10 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, done); } +void XlaDeviceContext::CopyDeviceTensorToDevice(const Tensor& src_tensor, + Tensor* dst_tensor, + const StatusCallback& done) { + manager_.CopyDeviceTensorToDevice(src_tensor, dst_tensor, done); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 9af9655868448ce5116db3611c5f88339135947e..ee346e5653bbf9f393df202572c2150b4989506f 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -55,6 +55,10 @@ class XlaTransferManager { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done); + + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const { return stream_; } private: @@ -72,7 +76,7 @@ class XlaTransferManager { xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. const bool transfer_as_literal_; - const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -90,6 +94,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, StringPiece tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyDeviceTensorToDevice(const Tensor& src_tensor, Tensor* dst_tensor, + const StatusCallback& done); + se::Stream* stream() const override { return manager_.stream(); } private: diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc index f68dba6b6a26c0c289fd8457ad143d62e5fb9a69..5ecb1afa7bcec910ca843ccd3a782745f2bb6ca8 100644 --- a/tensorflow/compiler/jit/xla_device_ops.cc +++ b/tensorflow/compiler/jit/xla_device_ops.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device_ops.h" +#include + #include "tensorflow/compiler/jit/xla_device_context.h" +#include "tensorflow/compiler/jit/xla_tensor.h" namespace tensorflow { @@ -26,4 +29,82 @@ void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) { << type_string() << " on an XLA device. This should never happen."; } +XlaAssignVariableOp::XlaAssignVariableOp(OpKernelConstruction* c) + : AsyncOpKernel(c) { + OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); +} + +void XlaAssignVariableOp::ComputeAsync(OpKernelContext* context, + DoneCallback done) { + OP_REQUIRES_ASYNC(context, dtype_ == context->input(1).dtype(), + errors::InvalidArgument( + "Variable and value dtypes don't match; respectively, ", + dtype_, " and ", context->input(1).dtype()), + done); + Var* variable = nullptr; + OP_REQUIRES_OK_ASYNC( + context, + LookupOrCreateResource( + context, HandleFromInput(context, 0), &variable, + [this, context](Var** ptr) { + *ptr = new Var(dtype_); + PersistentTensor unused; + Tensor* tmp; + AllocatorAttributes attr; + TF_RETURN_IF_ERROR(context->allocate_persistent( + dtype_, context->input(1).shape(), &unused, &tmp, attr)); + *(*ptr)->tensor() = *tmp; + return Status::OK(); + }), + done); + core::ScopedUnref s(variable); + + OP_REQUIRES_ASYNC(context, variable->tensor()->dtype() == dtype_, + errors::InvalidArgument( + "Trying to assign variable with wrong dtype. Expected ", + DataTypeString(variable->tensor()->dtype()), " got ", + DataTypeString(dtype_)), + done); + + const Tensor& value = context->input(1); + AllocatorAttributes attr; + + // Copying is unnecessary if we are the last user of the value tensor, we can + // just adopt the input tensor's buffer instead. + std::unique_ptr input_alias = context->forward_input( + 1, /*output_index=*/OpKernelContext::Params::kNoReservation, dtype_, + value.shape(), DEVICE_MEMORY, attr); + mutex_lock ml(*variable->mu()); + variable->is_initialized = true; + if (input_alias) { + *variable->tensor() = *input_alias; + done(); + return; + } + + // Need to copy, but maybe we can re-use variable's buffer? + if (!XlaTensor::RefCountIsOne(*variable->tensor()) || + !variable->tensor()->shape().IsSameSize(value.shape())) { + // Copy to new buffer + PersistentTensor unused; + Tensor* tmp; + OP_REQUIRES_OK_ASYNC(context, + context->allocate_persistent(dtype_, value.shape(), + &unused, &tmp, attr), + done); + *variable->tensor() = *tmp; + } + + XlaDeviceContext* device_context = + static_cast(context->op_device_context()); + + variable->Ref(); + device_context->CopyDeviceTensorToDevice( + value, variable->tensor(), [context, variable, done](Status status) { + variable->Unref(); + context->SetStatus(status); + done(); + }); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 9c00a0682ccdc08e7bb09e32d32f01e87e7aaf8d..11e45d2823da2b623bd3cd45f7147686b05fdb2f 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -26,7 +26,9 @@ limitations under the License. #include "tensorflow/core/kernels/identity_n_op.h" #include "tensorflow/core/kernels/identity_op.h" #include "tensorflow/core/kernels/no_op.h" +#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/sendrecv_ops.h" +#include "tensorflow/core/kernels/shape_ops.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { @@ -41,6 +43,15 @@ class XlaDeviceDummyOp : public OpKernel { void Compute(OpKernelContext* ctx) override; }; +class XlaAssignVariableOp : public AsyncOpKernel { + public: + explicit XlaAssignVariableOp(OpKernelConstruction* c); + void ComputeAsync(OpKernelContext* context, DoneCallback done) override; + + private: + DataType dtype_; +}; + #define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \ REGISTER_KERNEL_BUILDER(Name("XlaLaunch") \ .Device(DEVICE) \ @@ -73,7 +84,68 @@ class XlaDeviceDummyOp : public OpKernel { \ REGISTER_KERNEL_BUILDER( \ Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), \ - ResourceHandleOp); + ResourceHandleOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"), \ + ReadVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("Shape") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("ShapeN") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + ShapeNOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE) \ + .HostMemory("output") \ + .TypeConstraint("out_type") \ + .TypeConstraint("T", TYPES), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T", \ + TYPES), \ + RankOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"), \ + XlaAssignVariableOp); \ + REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \ + ControlTriggerOp); \ + REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \ + SwitchOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \ + REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \ + REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \ + REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \ + NextIterationOp); \ + REGISTER_KERNEL_BUILDER(Name("LoopCond") \ + .Device(DEVICE) \ + .HostMemory("input") \ + .HostMemory("output"), \ + LoopCondOp); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.cc b/tensorflow/compiler/jit/xla_fusion_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..74257b09a808a39454eace3b1a9bf57a2e071360 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.cc @@ -0,0 +1,328 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/graphcycles/graphcycles.h" +#include "tensorflow/compiler/jit/union_find.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +namespace tensorflow { + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +static bool IsShapeConsumerOp(const Node& node) { + return node.type_string() == "Shape" || node.type_string() == "ShapeN" || + node.type_string() == "Rank" || node.type_string() == "Size"; +} + +// Returns true if the op can be decomposed into XLA ops for which +// there are fusable elemental implementations. +bool IsXlaFusable(const NodeDef& node) { + static const std::unordered_set* elementwise_ops = + new std::unordered_set( + {// tf2xla/kernels/aggregate_ops.cc + "AddN", + // tf2xla/kernels/binary_ops.cc + "Add", "Sub", "Mul", "Div", "Atan2", "Complex", "FloorDiv", + "FloorMod", "BitwiseAnd", "BitwiseOr", "LeftShift", "RightShift", + "LogicalAnd", "LogicalOr", "Mod", "Maximum", "Minimum", "RealDiv", + "ReciprocalGrad", "RsqrtGrad", "SqrtGrad", "SquaredDifference", + "TruncateDiv", "TruncateMod", "Equal", "NotEqual", "Greater", + "GreaterEqual", "Less", "LessEqual", "SigmoidGrad", "SoftplusGrad", + "SoftsignGrad", "TanhGrad", "Pow", "ApproximateEqual", + // tf2xla/kernels/unary_ops.cc + "ComplexAbs", "Angle", "Conj", "Abs", "Acos", "Acosh", "Asin", + "Asinh", "Atan", "Atanh", "Ceil", "Cos", "Cosh", "Sin", "Exp", + "Expm1", "Floor", "IsFinite", "IsInf", "IsNan", "Inv", "Reciprocal", + "Log", "Log1p", "Invert", "LogicalNot", "Neg", "Rint", "Round", + "Rsqrt", "Sigmoid", "Sign", "Sinh", "Softplus", "Softsign", "Sqrt", + "Square", "Tan", "Tanh", "Real", "Imag", + // tf2xla/kernels/bcast_ops.cc + "BroadcastArgs", "BroadcastGradientArgs", + // tf2xla/kernels/bias_ops.cc + "BiasAdd", "BiasAddV1", "BiasAddGrad" /*(Reduce)*/, + // tf2xla/kernels/cast_op.cc + "Cast", + // tf2xla/kernels/concat_op.cc + "Concat", "ConcatV2", "ConcatOffset", + // tf2xla/kernels/const_op.cc + "Const", + // tf2xla/kernels/elu_op.cc + "Elu", "EluGrad", "Selu", "SeluGrad", + // tf2xla/kernels/fill_op.cc + "Fill", + // tf2xla/kernels/identity_op.cc + "Identity", "IdentityN", "PreventGradient", + "StopGradient", /*"Snapshot",*/ + // tf2xla/kernels/index_ops.cc + "ArgMax", "ArgMin", + // tf2xla/kernels/mirror_pad_op.cc + "MirrorPad", + // tf2xla/kernels/one_hot_op.cc + "OneHot", + // tf2xla/kernels/pack_op.cc + "Pack", + // tf2xla/kernels/pad_op.cc + "Pad", "PadV2", + // tf2xla/kernels/relu_op.cc + "Relu", "Relu6", "ReluGrad", "Relu6Grad", + // tf2xla/kernels/reshape_op.cc + "Reshape", + // tf2xla/kernels/reverse_op.cc + "Reverse", "ReverseV2", + // tf2xla/kernels/reverse_sequence_op.cc + "ReverseSequence", + // tf2xla/kernels/shape_op.cc + "Shape", "ShapeN", "Rank", "Size", "ExpandDims", "Squeeze", + "ZerosLike", "OnesLike", + // tf2xla/kernels/slice_op.cc + "Slice", + // tf2xla/kernels/split_op.cc + "Split", "SplitV", + // tf2xla/kernels/strided_slice_op.cc + "StridedSlice", "StridedSliceGrad", "ResourceStridedSliceAssign", + // tf2xla/kernels/tile_ops.cc + "Tile", + // tf2xla/kernels/transpose_op.cc + "Transpose", "InvertPermutation", + // tf2xla/kernels/unpack_op.cc + "Unpack"}); + + return elementwise_ops->count(node.op()) > 0; +} + +Status XlaFusionOptimizer::Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) { + VLOG(2) << "Here at fusion optimizer"; + + // TODO(hpucha): Implement encapsulation and replacing with XlaLaunch op. + // Once that happens, the expected interaction between this optimizer and when + // the global_jit_level is set is as follows: Fusion optimizer will replace + // appropriate fusion clusters with XlaLaunch nodes. The remaining graph can + // be further compiled where possible via mark_for_compilation_pass. Note that + // this might lead to inefficient clustering, and it is best to use either the + // fusion optimizer or the global_jit flag, and not combine the two. + + // Create a Graph out of GraphDef. This is required currently because the + // helpers around clustering, encapsulation etc work on graphs. + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + Graph graph(function_library); + ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(false); + shape_refiner.set_disable_constant_propagation(true); + ImportGraphDefOptions options; + // Graph optimization happens at the late stage of graph execution, when + // colocation constraints are already validated previously and the device + // placement of nodes has also completed, so there is no need to validate + // colocation constraints again. + options.validate_colocation_constraints = false; + options.validate_shape = false; + TF_RETURN_IF_ERROR( + ImportGraphDef(options, item.graph, &graph, &shape_refiner)); + + // Collect nodes that can be fused via XLA, while ignoring those that + // explicitly ask for XLA: (*) nodes that are marked to be compiled + // explicitly. (*) nodes assigned to XLA device. + OrderedNodeSet compilation_candidates; + for (Node* node : graph.op_nodes()) { + // If there is a _XlaCompile annotation, ignore the node if it is + // true. Nodes are marked with this attr via experimental_jit_scope, and + // will be handled by the mark_for_compilation pass. + bool compile = false; + Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + if (status.ok() && compile) { + continue; + } + // If there is already a _XlaCluster annotation, ignore the node. Nodes are + // marked with this attr to indicate they are already part of a cluster and + // hence ignored. + status = GetNodeAttr(node->attrs(), kXlaClusterAttr, &compile); + if (status.ok()) { + continue; + } + + // If there is an explicit XLA device placement, ignore the node. + DeviceType device_type(""); + TF_RETURN_IF_ERROR(DeviceToDeviceType(node->def().device(), &device_type)); + if (device_type.type_string().find("XLA") != string::npos) continue; + + // Assume all fusable ops are registered. + // TODO(hpucha): Check for registration if possible. + if (!IsXlaFusable(node->def())) { + continue; + } + + // XLA does not offer guaranteed aliasing between the input and output of + // the XLA cluster so it can't implement the forward-tensor-ref semantic. + // Leave such nodes out of XLA clusters. + if (HasForwardedRefInput(*node)) { + continue; + } + + compilation_candidates.insert(node); + } + + if (compilation_candidates.empty()) { + VLOG(2) << "No compilable candidates"; + *output = item.graph; + return Status::OK(); + } + + GraphCycles cycles; + TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph, &cycles)); + + // TODO(hpucha): Make clustering more robust. There are two known issues that + // we need to mitigate: (a) Non-resource variables can cause deadlocks + // when clustering changes order of execution. See b/77263461 for a specific + // example. (b) Queue operations can also cause deadlocks. See b/77261498 for + // example. + + struct Cluster { + // Identifies the node that represents this cluster in the cycle detection + // graph. + int representative = -1; + }; + + // Each compilation candidate belongs to a cluster. The cluster's + // representative names the node in the 'cycles' graph that represents the + // cluster. + std::vector> clusters(graph.num_node_ids()); + std::deque*> worklist; + for (Node* node : compilation_candidates) { + Cluster& cluster = clusters[node->id()].Get(); + cluster.representative = node->id(); + worklist.push_back(&clusters[node->id()]); + } + + // Repeatedly contract edges between clusters that are on the same device, + // provided the contraction would not create a cycle. This is a simplified + // version of the clustering in mark_for_compilation_pass that also deals with + // nodes that are explicitly tagged to be compiled/clustered. + while (!worklist.empty()) { + int from = worklist.front()->Get().representative; + worklist.pop_front(); + + Node* node_from = graph.FindNodeId(from); + if (node_from->IsControlFlow()) { + // Control flow nodes aren't compilation candidates and should never + // appear. + return errors::Internal( + "Found control flow node in clustering worklist: ", + node_from->type_string()); + } + for (int to : cycles.Successors(from)) { + if (to >= graph.num_node_ids()) { + // Node is a "frame" node that is present only in the cycle detection + // graph. No clustering is possible. + continue; + } + Node* node_to = graph.FindNodeId(to); + if (compilation_candidates.find(node_to) == + compilation_candidates.cend()) { + continue; + } + + // Do not cluster across devices. + if (node_from->def().device() != node_to->def().device()) { + VLOG(2) << "Devices " << node_from->def().device() << " " + << node_to->def().device(); + VLOG(2) << "Device names " << node_from->assigned_device_name() << " " + << node_to->assigned_device_name(); + continue; + } + + // Ops that consume shapes cannot be the root of a cluster. This is an + // optimization. + if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) { + continue; + } + + // If contracting the edge would create a cycle, bail out. + // However, just because we can't merge the clusters now does not mean + // we won't be able to merge them in the future. + // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge + // 1->3. But if we first contract 1->2 then we can later contract 1->3. + if (!cycles.ContractEdge(from, to)) continue; + + // Merge the clusters. ContractEdge uses 'from' as the number of the + // merged node, so make sure 'from' is the chosen representative. + clusters[from].Merge(&clusters[to]); + + worklist.push_back(&clusters[from]); + break; + } + } + + // Count the number of non-trivial elements in each cluster. + std::vector effective_cluster_sizes(graph.num_node_ids()); + for (const Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + // Identity nodes will be removed if the node gets marked for compilation. + // Therefore we don't want to count them towards the effective cluster size. + if (n->def().op() != "Identity") { + effective_cluster_sizes[cluster]++; + } + } + + const int min_cluster_size = 2; + int num_clusters = 0; + for (auto size : effective_cluster_sizes) { + if (size >= min_cluster_size) { + VLOG(3) << "Cluster " << num_clusters << " " << size; + num_clusters++; + } + } + + // Names for each cluster. + std::unordered_map cluster_names; + // Sequence number generator to ensure clusters have unique names. + static std::atomic cluster_sequence_num; + + for (Node* n : compilation_candidates) { + int cluster = clusters[n->id()].Get().representative; + + // Compile if this is a cluster of >= min_cluster_size compilable operators. + if (effective_cluster_sizes[cluster] >= min_cluster_size) { + string& name = cluster_names[cluster]; + + if (name.empty()) { + name = strings::StrCat("cluster_", cluster_sequence_num++); + } + n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; + } + } + + graph.ToGraphDef(output); + return Status::OK(); +} + +REGISTER_GRAPH_OPTIMIZER_AS(XlaFusionOptimizer, "xla-fusion"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer.h b/tensorflow/compiler/jit/xla_fusion_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..3d2309e782d38725f8db025fbfda0bf0f63d18be --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ +#define TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { + +// Optimizes graphs by fusing ops where possible, resulting in more efficient +// execution. +class XlaFusionOptimizer : public grappler::CustomGraphOptimizer { + public: + XlaFusionOptimizer() {} + ~XlaFusionOptimizer() override {} + + Status Init( + const RewriterConfig_CustomGraphOptimizer* config = nullptr) override { + return Status::OK(); + } + + string name() const override { return "xla-fusion"; }; + + Status Optimize(grappler::Cluster* cluster, + const grappler::GrapplerItem& item, + GraphDef* output) override; + + void Feedback(grappler::Cluster* cluster, const grappler::GrapplerItem& item, + const GraphDef& optimize_output, double result) override { + // Nothing to do for XlaFusionOptimizer. + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_JIT_XLA_FUSION_OPTIMIZER_H_ diff --git a/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5736760a878dc857a8558093054d0adc0f727398 --- /dev/null +++ b/tensorflow/compiler/jit/xla_fusion_optimizer_test.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/xla_fusion_optimizer.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/xla_cluster_util.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("UncompilableNullary").Output("o: float"); +REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float"); + +class XlaFusionOptimizerTest : public grappler::GrapplerTest { + protected: + std::unordered_map GetClusters(const GraphDef& graph) { + std::unordered_map ids; + for (const NodeDef& node : graph.node()) { + string cluster; + if (GetNodeAttr(AttrSlice(node), kXlaClusterAttr, &cluster).ok()) { + CHECK(!cluster.empty()); + ids[node.name()] = cluster; + } + } + return ids; + } +}; + +TEST_F(XlaFusionOptimizerTest, Chains) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = + ops::SourceOp("UncompilableNullary", builder.opts().WithName("A")); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); + Node* d = + ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); + ops::UnaryOp("Relu", e, builder.opts().WithName("F")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(4, clusters.size()); + EXPECT_EQ(clusters["B"], clusters["C"]); + EXPECT_EQ(clusters["E"], clusters["F"]); + EXPECT_NE(clusters["B"], clusters["E"]); + EXPECT_TRUE(clusters.find("A") == clusters.cend()); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, FusableOps) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp("Add", a, b, builder.opts().WithName("C")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(2, clusters.size()); + EXPECT_EQ(clusters["C"], clusters["E"]); + EXPECT_TRUE(clusters.find("D") == clusters.cend()); +} + +TEST_F(XlaFusionOptimizerTest, IgnoreExplicitXLAAttrs) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp( + "Placeholder", + builder.opts().WithName("A").WithAttr("dtype", tensorflow::DT_FLOAT)); + Node* b = ops::SourceOp( + "Placeholder", + builder.opts().WithName("B").WithAttr("dtype", tensorflow::DT_FLOAT)); + + Node* c = ops::BinaryOp( + "Add", a, b, + builder.opts().WithName("C").WithDevice("/device:XLA_CPU")); + ops::BinaryOp("MatMul", a, c, builder.opts().WithName("D")); + Node* e = ops::UnaryOp("Abs", c, builder.opts().WithName("E")); + ops::UnaryOp("Cos", e, + builder.opts().WithName("F").WithAttr(kXlaCompileAttr, true)); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, UncompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = + ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_TRUE(clusters.empty()); +} + +TEST_F(XlaFusionOptimizerTest, CompilableCycles) { + GraphDef graph; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::BinaryOp("Mul", a, b, builder.opts().WithName("C")); + TF_ASSERT_OK(builder.ToGraphDef(&graph)); + } + grappler::GrapplerItem item; + item.graph = graph; + + XlaFusionOptimizer optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + auto clusters = GetClusters(output); + EXPECT_EQ(3, clusters.size()); + EXPECT_EQ(clusters["A"], clusters["B"]); + EXPECT_EQ(clusters["A"], clusters["C"]); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 26842fbe5cc110fa9ce7a2767d245484fd67556d..c0d86a28c7698c302e28bab972bb2f847cc00ca4 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -49,7 +49,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, /*transfer_as_literal=*/false, - /*shape_representation_fn=*/{}, &device); + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4146996f6346446e715ffb225882cfb20359dae1..661187f4a873b03b8d013aa74cb6b6315bb4e2eb 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -48,11 +48,12 @@ Status XlaInterpreterDeviceFactory::CreateDevices( registration.compile_resource_ops = true; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create( - "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, - options, name_prefix, registration, - /*transfer_as_literal=*/false, - /*shape_representation_fn=*/{}, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, + DEVICE_INTERPRETER_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, + /*padded_shape_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index a7211c9c7e281a8141d5671b345c628441b2359d..3c44c4ae6df7f3e2d60d8933561c0c71888e8c3f 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -18,7 +18,7 @@ limitations under the License. namespace tensorflow { -/*static*/ XlaTensor* XlaTensor::FromTensor(Tensor* tensor) { +/*static*/ XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { if (tensor->NumElements() == 0) { return nullptr; } @@ -27,8 +27,8 @@ namespace tensorflow { return xla_tensor; } -/*static*/ const XlaTensor* XlaTensor::FromTensor(const Tensor* tensor) { - return FromTensor(const_cast(tensor)); +/*static*/ bool XlaTensor::RefCountIsOne(const Tensor& tensor) { + return tensor.RefCountIsOne(); } /*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor( @@ -67,6 +67,8 @@ Status XlaTensor::AllocateShapedBuffer(DataType dtype, const TensorShape& shape, index_to_buffer.second = buffer.Forget(); } + VLOG(4) << shaped_buffer.ToString(); + set_shaped_buffer(std::move(shaped_buffer)); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 6b29c82ec11e39ad525663991e179443c2b6dca7..c54001a999998f45c0cdacd752ca4036f0792857 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -34,10 +34,9 @@ class XlaTensor { public: // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast // fails. - static XlaTensor* FromTensor(Tensor* tensor); - // Downcast from a Tensor to an XlaTensor. Return nullptr if the downcast - // fails. - static const XlaTensor* FromTensor(const Tensor* tensor); + static XlaTensor* FromTensor(const Tensor* tensor); + + static bool RefCountIsOne(const Tensor& tensor); // Create a DeviceMemoryBase from a Tensor. The Tensor can be an XlaTensor, in // which case the returned value is shaped_buffer()->root_buffer(), or a @@ -62,6 +61,10 @@ class XlaTensor { CHECK(has_shaped_buffer()); return *shaped_buffer_; } + xla::ShapedBuffer& shaped_buffer() { + CHECK(has_shaped_buffer()); + return *shaped_buffer_; + } // Mutates the XlaTensor to set the ShapedBuffer. void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) { shaped_buffer_ = diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4c291d2383163e5def54657186c2190c023832fc..e6c92f9720e1285617280f60d1c5fea443c5ebef 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -120,6 +120,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "bucketize_op_test", + size = "small", + srcs = ["bucketize_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "categorical_op_test", size = "small", @@ -532,7 +545,9 @@ tf_xla_py_test( ], deps = [ ":xla_test", + "//tensorflow/python:array_ops", "//tensorflow/python:framework", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], diff --git a/tensorflow/compiler/tests/bucketize_op_test.py b/tensorflow/compiler/tests/bucketize_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fde9759a1c209844caac99d5f303cd3e406e5370 --- /dev/null +++ b/tensorflow/compiler/tests/bucketize_op_test.py @@ -0,0 +1,78 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for bucketize_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class BucketizationOpTest(XLATestCase): + + def testInt(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual(expected_out, + sess.run(op, {p: [-5, 0, 2, 3, 5, 8, 10, 11, 12]})) + + def testFloat(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0., 3., 8., 11.]) + expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4] + self.assertAllEqual( + expected_out, + sess.run(op, {p: [-5., 0., 2., 3., 5., 8., 10., 11., 12.]})) + + def test2DInput(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 3, 8, 11]) + expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]] + self.assertAllEqual( + expected_out, sess.run(op, + {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) + + def testInvalidBoundariesOrder(self): + with self.test_session() as sess: + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + op = math_ops._bucketize(p, boundaries=[0, 8, 3, 11]) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "Expected sorted boundaries"): + sess.run(op, {p: [-5, 0]}) + + def testBoundariesNotList(self): + with self.test_session(): + with self.assertRaisesRegexp(TypeError, "Expected list.*"): + p = array_ops.placeholder(dtypes.int32) + with self.test_scope(): + math_ops._bucketize(p, boundaries=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/depthwise_conv_op_test.py b/tensorflow/compiler/tests/depthwise_conv_op_test.py index 0a0d335ca76dd7ec7ca3b12f9e8a83b596daa07e..03d96a2cd8ab22a472a67f092e36224820405fa8 100644 --- a/tensorflow/compiler/tests/depthwise_conv_op_test.py +++ b/tensorflow/compiler/tests/depthwise_conv_op_test.py @@ -153,7 +153,7 @@ class DepthwiseConv2DTest(XLATestCase): dtype=data_type).reshape(filter_in_sizes) with self.test_session() as sess: if data_type == np.float32: - tolerance = 1e-5 + tolerance = 1e-4 else: self.assertEqual(data_type, np.float64) tolerance = 1e-8 @@ -339,7 +339,7 @@ class DepthwiseConv2DTest(XLATestCase): gpu_value = _GetVal(use_xla=True) cpu_value = _GetVal(use_xla=False) - self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4) + self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-3) def testDepthwiseConv2DInputGradCompare(self): for index, (input_size, filter_size, output_size, stride, diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 52d8d6d295c428f2c3466ef2963223cc978b4277..a4154ad1e846f8241a2ab6598da36ccb6b3b653e 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -31,11 +31,13 @@ from tensorflow.python.framework import ops from tensorflow.python.layers import convolutional from tensorflow.python.layers import pooling from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import googletest +from tensorflow.python.training import adam class EagerTest(XLATestCase): @@ -117,6 +119,15 @@ class EagerTest(XLATestCase): v.assign_add(2.0) self.assertEqual(3.0, v.numpy()) + def testReadAssignRead(self): + with self.test_scope(): + v = resource_variable_ops.ResourceVariable(1.0) + val1 = v.read_value() + v.assign_add(2.0) + val2 = v.read_value() + self.assertEqual(1.0, val1.numpy()) + self.assertEqual(3.0, val2.numpy()) + def testGradient(self): def f(x): return x @@ -136,6 +147,129 @@ class EagerTest(XLATestCase): grads = backprop.implicit_grad(f)() self.assertEqual(2., grads[0][0].numpy()) + def testMultipleVariableReads(self): + # This test makes sure consecutive variable reads don't copy + # the underlying memory. + with self.test_scope(): + # Create 128MiB variables + var = resource_variable_ops.ResourceVariable( + array_ops.ones([32, 1024, 1024])) + + # Read the same variable 100 times. If the underlying tensor + # is not copied, this is a trivial operation. If it is copied, + # this will eat over 13GB and OOM. + values = [] + for _ in range(100): + values.append(var.value()) + + # The shape, shape_n, size, and rank are tested here because their + # execution kernels (as opposed to compilation only tf2xla kernels) + # are distincts from tf2xla kernels. + + def testShape(self): + def const(value): + return array_ops.shape( + constant_op.constant(value)).numpy() + + def ones(value): + return array_ops.shape( + array_ops.ones(value)).numpy() + + with self.test_scope(): + # Shapes of directly constructed tensors + self.assertAllEqual([], const(3)) + self.assertAllEqual([3], const([1.0, 2.0, 3.0])) + self.assertAllEqual([2, 2], const([[1.0, 2.0], [3.0, 4.0]])) + self.assertAllEqual([2, 1, 2], const([[[1.0, 2.0]], [[3.0, 4.0]]])) + + # Shapes of tensors created by op running on device + # We make this distinction because directly constructed tensors + # are treated differently in a few places that can influence shape: + # - they always have on_host_tensor + # - they and their shapes can be cached + # - they end up on device via a copy, instead of as program output + self.assertAllEqual([], ones([])) + self.assertAllEqual([3], ones([3])) + self.assertAllEqual([2, 2], ones([2, 2])) + self.assertAllEqual([2, 1, 2], ones([2, 1, 2])) + + def testShapeN(self): + with self.test_scope(): + # Shapes of directly constructed tensors + shapes = array_ops.shape_n([ + constant_op.constant(1.0), + constant_op.constant([1.0, 2.0, 3.0]), + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + # Shapes of tensors created by op running on device + shapes = array_ops.shape_n([ + array_ops.ones([]), + array_ops.ones([3]), + array_ops.ones([2, 2])]) + self.assertAllEqual( + [[], [3], [2, 2]], + [x.numpy().tolist() for x in shapes]) + + def testSize(self): + with self.test_scope(): + self.assertEqual( + 1, array_ops.size(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 3, array_ops.size(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 4, array_ops.size( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testRank(self): + with self.test_scope(): + self.assertEqual( + 0, array_ops.rank(constant_op.constant(1.0)).numpy()) + self.assertEqual( + 1, array_ops.rank(constant_op.constant([1.0, 2.0, 3.0])).numpy()) + self.assertEqual( + 2, array_ops.rank( + constant_op.constant([[1.0, 2.0], [3.0, 4.0]])).numpy()) + + def testAdam(self): + with self.test_scope(): + optimizer = adam.AdamOptimizer(0.1) + x = resource_variable_ops.ResourceVariable(10.0) + with backprop.GradientTape() as tape: + y = x * x + dy_dx = tape.gradient(y, x) + optimizer.apply_gradients([(dy_dx, x)]) + self.assertAlmostEqual(9.9, x.numpy(), places=3) + + def testAdamSparse(self): + with ops.device('/cpu:0'): + # Create 2-D embedding for 3 objects on CPU because sparse/sliced updates + # are not implemented on TPU. + embedding_matrix = resource_variable_ops.ResourceVariable( + array_ops.ones([3, 2])) + + with self.test_scope(): + with backprop.GradientTape() as tape: + embedding = embedding_ops.embedding_lookup(embedding_matrix, [1]) + y = math_ops.reduce_sum(embedding) + dy_dx = tape.gradient(y, embedding_matrix) + self.assertIsInstance(dy_dx, ops.IndexedSlices) + optimizer = adam.AdamOptimizer(0.1) + # The gradient application operations will run on CPU because optimizer + # updates are always collocated with the variable. + optimizer.apply_gradients([(dy_dx, embedding_matrix)]) + + # This assign_add will run on CPU because when an input to an + # operation is a resource, this operation is placed on the resource's + # device by the eager runtime. + embedding_matrix.assign_add(array_ops.ones([3, 2])) + + self.assertAllClose([[2.0, 2.0], + [1.9, 1.9], + [2.0, 2.0]], embedding_matrix.numpy()) + class EagerFunctionTest(XLATestCase): diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 42e637734c578fcc70473060cb156e172a0a1995..7cf953ef25ef5daf8a6d4fc9985ed8dbfb2081e5 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -65,9 +65,7 @@ class RGBToHSVTest(XLATestCase): join1 = array_ops.stack(split1) join2 = array_ops.stack(split2) batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], - { - batch0: inp - }) + {batch0: inp}) # Verify that processing batch elements together is the same as separate self.assertAllClose(batch1, join1) @@ -401,9 +399,7 @@ class AdjustSaturationTest(XLATestCase): x = array_ops.placeholder(dtypes.float32, shape=x_shape) with self.test_scope(): y_fused = self._adjust_saturation(x, - scale).eval(feed_dict={ - x: x_np - }) + scale).eval(feed_dict={x: x_np}) self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) @@ -412,7 +408,8 @@ class ResizeBilinearTest(XLATestCase): def _assertForwardOpMatchesExpected(self, image_np, target_shape, - expected=None): + expected=None, + large_tolerance=False): if expected is None: self.fail("expected must be specified") with self.test_session() as sess, self.test_scope(): @@ -420,7 +417,11 @@ class ResizeBilinearTest(XLATestCase): resized = gen_image_ops.resize_bilinear( image, target_shape, align_corners=True) out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) - self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + if large_tolerance: + self.assertAllClose( + expected[np.newaxis, :, :, np.newaxis], out, rtol=0.03, atol=0.1) + else: + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) def _assertBackwardOpMatchesExpected(self, grads_np, @@ -555,6 +556,28 @@ class ResizeBilinearTest(XLATestCase): [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], dtype=np.float32)) + def testAlignCorners4x4To8x8(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3]], dtype=np.float32) + np.array( + [[0], [1], [2], [3]], dtype=np.float32)) * 7.0, [8, 8], + expected=3 * + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)), + large_tolerance=True) + + def testAlignCorners8x8To16x16(self): + self._assertForwardOpMatchesExpected( + (np.array([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7]], dtype=np.float32)) * 15.0, + [16, 16], + expected=7 * (np.array( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], + dtype=np.float32) + np.array( + [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], + [12], [13], [14], [15]], + dtype=np.float32)), + large_tolerance=True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 4b0043b6b4c7fbf57ec1507b84adf18daaea9363..6e0db54b7a74b284dc7d18bcbb07c178c664c1e5 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -125,7 +125,7 @@ class JitLaunchTest(test.TestCase): for (x, y) in zip(compiled, direct): self.assertAllClose(x, y, rtol=1e-1) else: - self.assertAllClose(compiled, direct) + self.assertAllClose(compiled, direct, rtol=1e-2) def testNoOutputs(self): with session_lib.Session() as sess: diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index d6c93088d4efff7d8306e262a79ae49d3d8ac722..f13dff96203b5480480c2a2fc9ac38ca78b7f78a 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -22,6 +22,8 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import googletest @@ -47,18 +49,18 @@ class RandomOpsTest(XLATestCase): # We use exact equality here. If the random-number generator is producing # deterministic output, all three outputs will be bitwise identical. self.assertTrue((not np.array_equal(y, z)) or - (not np.array_equal(z, w)) or - (not np.array_equal(y, w))) + (not np.array_equal(z, w)) or (not np.array_equal(y, w))) def testRandomUniformIsNotConstant(self): + def rng(dtype): - return random_ops.random_uniform(shape=[2], dtype=dtype, - maxval=1000000) + return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): + def rng(dtype): return random_ops.random_normal(shape=[2], dtype=dtype) @@ -70,12 +72,20 @@ class RandomOpsTest(XLATestCase): for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): - x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, - maxval=33) + x = random_ops.random_uniform( + shape=[1000], dtype=dtype, minval=-2, maxval=33) y = sess.run(x) self.assertTrue((y >= -2).sum() == 1000) self.assertTrue((y < 33).sum() == 1000) + def testTruncatedNormalIsNotConstant(self): + + def rng(dtype): + return random_ops.truncated_normal(shape=[2], dtype=dtype) + + # TODO(b/34339814): implement inverse erf support for non-F32 types. + self._testRngIsNotConstant(rng, dtypes.float32) + def testTruncatedNormalIsInRange(self): count = 10000 # TODO(b/34339814): implement inverse erf support for non-F32 types. @@ -87,6 +97,29 @@ class RandomOpsTest(XLATestCase): self.assertTrue((y >= -2).sum() == count) self.assertTrue((y <= 2).sum() == count) + def testShuffle1d(self): + with self.test_session() as sess: + with self.test_scope(): + x = math_ops.range(20) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = range(20) + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(set(result), set(expected)) + + def testShuffle2d(self): + with self.test_session() as sess: + with self.test_scope(): + x = array_ops.diag(math_ops.range(20)) + shuffle = random_ops.random_shuffle(x) + result = sess.run(shuffle) + expected = np.diag(range(20)).flatten() + # Compare sets to avoid randomness behavior changes but make sure still + # have all the values. + self.assertAllEqual(len(result.flatten()), len(expected)) + self.assertAllEqual(set(result.flatten()), set(expected)) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index 8ecad00f6e23b3a7746bbb473102ac847bf4cbfd..2c09b03d5a35cde2c42d8a145781270c0c908587 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -187,6 +187,25 @@ class VariableOpsTest(XLATestCase): rtol=1e-4) self.assertAllClose(np.array([1.9, 2.9], dtype=np.float32), vb, rtol=1e-4) + def testWriteOfAliasedTensor(self): + for dtype in self.numeric_types: + init = np.array([[1, 2j], [3, 4]]).astype(dtype) + update = np.array([[7, 1j], [2, 11]]).astype(dtype) + with self.test_session() as sess, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + sess.run(variables.variables_initializer([v])) + p = array_ops.placeholder(dtype) + q = array_ops.identity(p) + x = v.read_value() + # Writes the value of 'p' to 'v', but keeps a reference to the original + # value of 'v' so the variable update cannot reuse its buffer. + with ops.control_dependencies([x]): + y = v.assign(q) + result = sess.run([x, y, q], {p: update}) + self.assertAllClose(init, result[0]) + self.assertAllClose(update, result[1]) + self.assertAllClose(update, result[2]) + class StridedSliceAssignChecker(object): """Compares the results of a slice assignment using Tensorflow and numpy.""" diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index b707bd0963d71d7c4b43b8d42752b4c50e9bbf7c..f0b010fa67f2ffb3f81fd14d4d89585f716b4890 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test @@ -46,6 +47,12 @@ class XlaDeviceTest(XLATestCase): result = sess.run(z, {x: inputs}) self.assertAllCloseAccordingToType(result, inputs + inputs) + def testControlTrigger(self): + with self.test_session() as sess: + with self.test_scope(): + x = gen_control_flow_ops.control_trigger() + sess.run(x) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 42585ad4d8a17d71146e48b69f9fa56f9ff24c3e..1438f6b48c4913e60b0c0a9f5c3d67fe595cbfe8 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -1438,7 +1438,13 @@ Status FunctionalizeControlFlow(const FunctionLibraryDefinition* lookup_library, // connected to all source nodes in the graph. Many graphs violate this // invariant. std::vector cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info)); + std::vector unreachable_nodes; + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info, &unreachable_nodes)); + if (!unreachable_nodes.empty()) { + return errors::InvalidArgument( + "The following nodes are unreachable from the source in the graph: ", + tensorflow::str_util::Join(unreachable_nodes, ", ")); + } // Builds Frames, indexed by name. std::unordered_map frames; diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index e6da157c111ad9167bf7b1e743d9afbb8fb2ad03..edd2ab6301ee891c433639ce300cde0c72929cea 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -18,6 +18,7 @@ tf_kernel_library( "bcast_ops.cc", "bias_ops.cc", "binary_ops.cc", + "bucketize_op.cc", "cast_op.cc", "categorical_op.cc", "cholesky_op.cc", diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca9a6b40688d1e8496d1b823e20d273d519f65e8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +class BucketizeOp : public XlaOpKernel { + public: + explicit BucketizeOp(OpKernelConstruction* context) : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("boundaries", &boundaries_)); + OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), + errors::InvalidArgument("Expected sorted boundaries")); + } + + void Compile(XlaOpKernelContext* context) override { + xla::XlaBuilder* builder = context->builder(); + const DataType dtype = context->input_type(0); + xla::XlaOp input = context->Input(0); + + xla::XlaOp boundaries = builder->ConstantR1(boundaries_); + // TODO(phawkins): the following behavior matches the behavior of the core + // Bucketize kernel. However, comparing an int32 or int64 against float may + // lead to inaccurate bucketing due to rounding. + if (dtype == DT_DOUBLE) { + input = builder->ConvertElementType(input, xla::F64); + boundaries = builder->ConvertElementType(boundaries, xla::F64); + } else { + input = builder->ConvertElementType(input, xla::F32); + } + xla::XlaOp comparison = builder->ConvertElementType( + builder->Ge(builder->Broadcast(input, {1}), boundaries, + /*broadcast_dimensions=*/{0}), + xla::S32); + xla::XlaOp buckets = builder->Reduce( + comparison, /*init_value=*/builder->ConstantR0(0), + /*computation=*/xla::CreateScalarAddComputation(xla::S32, builder), + /*dimensions_to_reduce=*/{0}); + context->SetOutput(0, buckets); + } + + private: + std::vector boundaries_; +}; + +REGISTER_XLA_OP(Name("Bucketize"), BucketizeOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 8b9b026643cf35216a2082dfcce9270c017bd14f..d48c6eea754f75a8879d3938f233a6a591d26d0d 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -48,11 +48,11 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { VLOG(1) << "Building If: " << input_types_.size() << " inputs"; - std::vector inputs(input_types_.size()); std::vector arguments(input_types_.size()); for (int i = 0; i < input_types_.size(); ++i) { XlaCompiler::Argument& arg = arguments[i]; DataType type = ctx->input_type(i + 1); + if (type == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource)); @@ -60,7 +60,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.initialized = resource->initialized(); arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = resource->kind(); - OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); arg.type = resource->type(); arg.shape = resource->shape(); @@ -79,7 +78,6 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { arg.kind = XlaCompiler::Argument::kParameter; arg.type = input_types_[i]; arg.shape = ctx->InputShape(i + 1); - inputs[i] = ctx->Input(i + 1); VLOG(2) << "Arg type: " << DataTypeString(arg.type) << " shape: " << arg.shape.DebugString(); } @@ -100,6 +98,7 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, arguments, &else_result)); + bool has_tensor_array_gradients = false; for (XlaCompiler::CompilationResult* result : {&then_result, &else_result}) { for (const XlaCompiler::ResourceUpdate& update : result->resource_updates) { XlaResource* resource; @@ -121,9 +120,21 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } + if (!resource->tensor_array_gradients().empty()) + has_tensor_array_gradients = true; } } + // Recompile the functions to update the argument shapes for tensor arrays. + if (has_tensor_array_gradients) { + then_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, then_branch_, + arguments, &then_result)); + else_result = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(options, else_branch_, + arguments, &else_result)); + } + // Check that both branches have identical input shapes. OP_REQUIRES(ctx, then_result.xla_input_shapes.size() == 1, errors::FailedPrecondition("Expected one input shape")); @@ -175,6 +186,19 @@ void XlaIfOp::Compile(XlaOpKernelContext* ctx) { "Mismatch in resource of then and else branch for resource ", i)); } + int num_inputs = then_result.input_mapping.size(); + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = then_result.input_mapping[i] + 1; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b)); + } else { + inputs[i] = ctx->Input(i + 1); + } + } + xla::XlaOp outputs = b->Conditional(ctx->Input(0), b->Tuple(inputs), *then_result.computation, b->Tuple(inputs), *else_result.computation); diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index 9058cbc74762576c7e6f8ec1b2b0f6b247ac0502..79d3a6979cec4c6bda92a71dcff4ddd2151367d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -99,27 +99,34 @@ ResizeConvolutionDims ComputeResizeConvolutionParameters( return dims; } +// Form a 2D convolution kernel like: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 +// by multiplying two 1D kernels of the form: +// 1/3 * [1 2 3 2 1] +// If the 2D kernel would be very large, the 1D kernel can be applied once in +// each dimension due to the symmetry of the kernel along all axis to reduce the +// computational intensity. +std::vector Make1DKernel(int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; +} + +// Kernels with more than 16 spatial elements are considered intense and the +// kernel should applied to each dimension independently. +const int64 kMax2DKernelSize = 16; + xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, gtl::ArraySlice kernel_size, int64 channels) { - // Form a 2D convolution kernel like: - // 1 2 3 2 1 - // 2 4 6 4 2 - // 1/9 * 3 6 9 6 3 - // 2 4 6 4 2 - // 1 2 3 2 1 - // by multiplying two 1D kernels of the form: - // 1/3 * [1 2 3 2 1] - auto make_1d_kernel = [](int64 n) { - std::vector kernel(n * 2 - 1); - for (int64 i = 0; i < n; ++i) { - float v = (i + 1.0f) / n; - kernel[i] = v; - kernel[n * 2 - 2 - i] = v; - } - return kernel; - }; - xla::XlaOp channels_iota; // DT_INT32 Iota will always return status::OK(). TF_CHECK_OK( @@ -133,12 +140,37 @@ xla::XlaOp MakeBilinearResizeKernel(xla::XlaBuilder* builder, xla::PrimitiveType::F32); return builder->Mul( builder->Mul(diag, - builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR1(Make1DKernel(kernel_size[1])), /*broadcast_dimensions=*/{1}), - builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->ConstantR1(Make1DKernel(kernel_size[0])), /*broadcast_dimensions=*/{0}); } +xla::XlaOp MakeBilinearResizeKernelInDim(xla::XlaBuilder* builder, + gtl::ArraySlice kernel_size, + int64 channels, int64 dim) { + xla::XlaOp channels_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq(builder->Broadcast( + channels_iota, + {dim == 0 ? (2 * kernel_size[0] - 1) : 1, + dim == 1 ? (2 * kernel_size[1] - 1) : 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + if (dim == 1) { + return builder->Mul( + diag, builder->ConstantR1(Make1DKernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}); + } + return builder->Mul(diag, + builder->ConstantR1(Make1DKernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, const xla::XlaOp& input, const int num_spatial_dims, @@ -165,20 +197,42 @@ xla::XlaOp ResizeUsingDilationAndConvolution(xla::XlaBuilder* builder, dimension_numbers.add_output_spatial_dimensions(1 + i); dimension_numbers.add_kernel_spatial_dimensions(i); } - dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); - dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims + 1); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims); ResizeConvolutionDims dims = ComputeResizeConvolutionParameters(in_size, out_size); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); - xla::XlaOp output = builder->ConvGeneralDilated( - input, kernel, dims.stride, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.kernel_size, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp output; + // Split convolutions into independent dimensions if they wmuld be a very + // large kernel. + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + output = builder->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + output = builder->ConvGeneralDilated( + input, kernel0, {dims.stride[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.kernel_size[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + output = builder->ConvGeneralDilated( + output, kernel1, {1, dims.stride[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.kernel_size[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // Add broadcasts to handle expanding from a size == 1 dimension to a // size > 1 dimension. @@ -214,26 +268,63 @@ xla::XlaOp ResizeUsingDilationAndConvolutionGradOp(xla::XlaBuilder* builder, } dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); - xla::XlaOp kernel = - MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::XlaOp output; + if (dims.kernel_size[0] * dims.kernel_size[1] < kMax2DKernelSize) { + xla::XlaOp kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = + builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } - // Broadcast the input kernel where the forward op expanded from a size == 1 - // dimension to a size > 1 dimension. This has the effect of summing the - // gradient contributions in that dimension. - for (int i = 0; i < num_spatial_dims; ++i) { - if (in_size[i] == 1 && grad_size[i] > 1) { - kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), - /*broadcast_dimensions=*/{i}); + output = builder->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } else { + xla::XlaOp kernel0 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 0); + xla::XlaOp kernel1 = + MakeBilinearResizeKernelInDim(builder, dims.kernel_size, channels, 1); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + if (in_size[0] == 1 && grad_size[0] > 1) { + kernel0 = + builder->Add(kernel0, builder->ConstantR1(grad_size[0], 0), + /*broadcast_dimensions=*/{0}); + } + if (in_size[1] == 1 && grad_size[1] > 1) { + kernel1 = + builder->Add(kernel0, builder->ConstantR1(grad_size[1], 0), + /*broadcast_dimensions=*/{1}); } - } - xla::XlaOp output = builder->ConvGeneralDilated( - grad, kernel, /*window_strides=*/dims.kernel_size, - /*padding=*/ - {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, - {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, - /*lhs_dilation=*/dims.stride, - /*rhs_dilation=*/{1, 1}, dimension_numbers); + output = builder->ConvGeneralDilated( + grad, kernel0, /*window_strides=*/{dims.kernel_size[0], 1}, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, {0, 0}}, + /*lhs_dilation=*/{dims.stride[0], 1}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + output = builder->ConvGeneralDilated( + output, kernel1, /*window_strides=*/{1, dims.kernel_size[1]}, + /*padding=*/ + {{0, 0}, {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/{1, dims.stride[1]}, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + } // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. // Opposite of the slice performed by the forward op. diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc index 8c8a9bbe787f3224e7444b62dcf8ad99130cf37f..65ab9da8d7ca0509a4a69c43727a0e6c0435908a 100644 --- a/tensorflow/compiler/tf2xla/kernels/no_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc @@ -24,8 +24,7 @@ namespace tensorflow { REGISTER_XLA_OP(Name("NoOp").CompilationOnly(), NoOp); // We register ControlTrigger as a no-op. This is correct since nodes seen -// by the XLA compiler are never dead. This may need rethinking when we add -// support for conditionals to XLA. -REGISTER_XLA_OP(Name("ControlTrigger"), NoOp); +// by the XLA compiler are never dead. +REGISTER_XLA_OP(Name("ControlTrigger").CompilationOnly(), NoOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 5f5bd586376ab368e443671ac8a5de23a5fd604b..105be38fe26b6667e8b4ce6da92a3969cdc0c187 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -17,6 +17,9 @@ limitations under the License. // TODO(misard,phawkins): handle random number generator seeds/states correctly. // TODO(misard,phawkins): add tests. +#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -55,6 +58,78 @@ class RandomUniformOp : public XlaOpKernel { REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), RandomUniformOp); +class RandomShuffleOp : public XlaOpKernel { + public: + explicit RandomShuffleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + auto builder = ctx->builder(); + xla::XlaOp input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + const int64 n = input_shape.dim_size(0); + int64 num_elements = 1; + for (tensorflow::TensorShapeDim dimension : input_shape) { + num_elements *= dimension.size; + } + if (num_elements <= 1 || n <= 1) { + // No shuffling is required, so copy input directly to output + ctx->SetOutput(0, input); + } else { + // Generate the random swaps for the indices. + auto swaps_shape = xla::ShapeUtil::MakeShape(xla::S32, {n}); + auto swaps = + builder->RngUniform(builder->ConstantR0(0), + builder->ConstantR0(n), swaps_shape); + + // Generate range(n) as the initial value for the indices to be swapped. + xla::XlaOp indices; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, n, &indices)); + + // Swap the indices at i and swaps[i]. + auto swap_body_fn = [&](xla::XlaOp i, + gtl::ArraySlice loop_vars, + xla::XlaBuilder* builder) + -> xla::StatusOr> { + auto swaps = loop_vars[0]; + auto indices = loop_vars[1]; + i = builder->Reshape(i, {1}); + // temp = indices[i] + auto temp = builder->DynamicSlice(indices, i, {1}); + // swap_index = swaps[i] + auto swap_index = builder->DynamicSlice(swaps, i, {1}); + // swap_value = indices[swaps[i]] + auto swap_value = builder->DynamicSlice(indices, swap_index, {1}); + // indices[i] = indices[swaps[i]] + indices = builder->DynamicUpdateSlice(indices, swap_value, i); + // indices[swaps[i]] = temp + indices = builder->DynamicUpdateSlice(indices, temp, swap_index); + return std::vector{swaps, indices}; + }; + // for i in range(n): + auto swap_loop_result = + XlaForEachIndex(n, xla::S32, swap_body_fn, {swaps, indices}, + "indices_swap_loop", builder) + .ValueOrDie(); + auto swapped_indices = swap_loop_result[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto indices_tensor_shape = TensorShape({n}); + DataType type = ctx->expected_output_dtype(0); + xla::XlaOp gather; + OP_REQUIRES_OK(ctx, XlaGather(input, input_shape, swapped_indices, + indices_tensor_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, + DT_INT32, builder, &gather)); + ctx->SetOutput(0, gather); + } + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleOp); +}; + +REGISTER_XLA_OP(Name("RandomShuffle"), RandomShuffleOp); + class RandomUniformIntOp : public XlaOpKernel { public: explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -127,13 +202,8 @@ class TruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - xla::Shape xla_element_shape = - xla::ShapeUtil::MakeShape(xla_shape.element_type(), {}); xla::XlaBuilder* b = ctx->builder(); - xla::XlaOp mean = XlaHelpers::Zero(b, dtype); - xla::XlaOp stddev = XlaHelpers::One(b, dtype); - xla::XlaOp candidate = b->RngNormal(mean, stddev, xla_shape); auto two_sd = [dtype](bool negate, xla::XlaBuilder* b) { return XlaHelpers::FloatLiteral(b, dtype, negate ? -2.0 : 2.0); @@ -151,34 +221,38 @@ class TruncatedNormalOp : public XlaOpKernel { // out_of_range_mask := candidate < mean-2*sd || candidate > mean+2*sd // candidate = select(out_of_range_mask, rng_normal(), candidate) // } - std::unique_ptr test_builder = - b->CreateSubBuilder("truncated_normal_test"); - { - auto* b = test_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - out_of_range_mask(candidate, b); - OP_REQUIRES_OK(ctx, Any(out_of_range_mask(candidate, b), b).status()); - } - - std::unique_ptr body_builder = - b->CreateSubBuilder("truncated_normal_body"); - { - auto* b = body_builder.get(); - xla::XlaOp candidate = b->Parameter(0, xla_shape, "candidate"); - xla::XlaOp to_resample = out_of_range_mask(candidate, b); + std::vector initial_values = { + // The current candidate. + b->Broadcast(XlaHelpers::Zero(b, dtype), shape.dim_sizes()), + // The to_resample mask, where 'true' identifies a location in the + // current candidate that is out of range and must be regenerated. + b->Broadcast(b->ConstantR0(true), shape.dim_sizes()), + // Is any element in the mask true? + b->ConstantR0(true)}; + auto condition = [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr { + // Continue while any element in the mask is true. + return values[2]; + }; + auto body = + [&](gtl::ArraySlice values, + xla::XlaBuilder* b) -> xla::StatusOr> { + xla::XlaOp candidate = values[0]; + xla::XlaOp to_resample = values[1]; xla::XlaOp mean = XlaHelpers::Zero(b, dtype); xla::XlaOp stddev = XlaHelpers::One(b, dtype); - b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), candidate); - } - - xla::StatusOr test_computation = test_builder->Build(); - OP_REQUIRES_OK(ctx, test_computation.status()); - xla::StatusOr body_computation = body_builder->Build(); - OP_REQUIRES_OK(ctx, body_computation.status()); - xla::XlaOp result = b->While(test_computation.ValueOrDie(), - body_computation.ValueOrDie(), candidate); - - ctx->SetOutput(0, result); + candidate = b->Select(to_resample, b->RngNormal(mean, stddev, xla_shape), + candidate); + // Compute a new to_resample mask, and determine whether any value is + // still out of range. + to_resample = out_of_range_mask(candidate, b); + TF_ASSIGN_OR_RETURN(xla::XlaOp done, Any(to_resample, b)); + return std::vector{candidate, to_resample, done}; + }; + auto result = + XlaWhileLoop(condition, body, initial_values, "truncated_normal", b); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()[0]); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 05354bca5bb089703fdcceb6f44648bbb98d004b..d59720bef742c7441ee01a954247013559bb909c 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -43,7 +43,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape"), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -65,7 +65,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -81,7 +81,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank"), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -100,7 +100,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size"), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: @@ -189,10 +189,9 @@ class SqueezeOp : public XlaOpKernel { if (!wrapped_squeeze_dims.empty()) { if (wrapped_squeeze_dims.count(i) > 0) { OP_REQUIRES(ctx, existing_dim == 1, - errors::InvalidArgument("Tried to explicitly squeeze " - "dimension ", - i, " but dimension was not 1: ", - existing_dim)); + errors::InvalidArgument( + "Tried to explicitly squeeze dimension ", i, + " but dimension was not 1: ", existing_dim)); } else { // This dimension is not being squeezed. new_shape.push_back(existing_dim); diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 6109db8e89e5ee67e0635d26e258bfe7cb70a15d..a163fa0a5b34675e46d0d7c5f4e0ccb1e3fb18eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -57,7 +57,7 @@ class ReadVariableOp : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); +REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp); class AssignVariableOp : public XlaOpKernel { public: @@ -67,7 +67,7 @@ class AssignVariableOp : public XlaOpKernel { ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); } }; -REGISTER_XLA_OP(Name("AssignVariableOp"), AssignVariableOp); +REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp); class AssignAddVariableOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index 3f1384bc864abd882ebba2b90acbe0b1e664687a..20925118bf598a6436c43bd727ce40e3abafc46c 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -110,7 +110,6 @@ xla::StatusOr CholeskyUnblocked(xla::XlaBuilder* builder, FloatLiteral(body_builder, a_shape.element_type(), 0.5)); // a[..., i+1:, i] - auto ip1 = body_builder->Add(i, body_builder->ConstantR0(1)); // select the whole i-th column, then mask out all rows above i+1 TF_ASSIGN_OR_RETURN( auto a_0i, DynamicSliceInMinorDims(body_builder, body_a, {i}, {1})); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 43e1c1e9fecec1c71db1509757251cb5d903ca49..db56b128375ce8ff2faf12c5d7ea256bdfab0f63 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,6 +40,37 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), + host_tensor.shape(), &xla_shape)); + *literal = xla::BorrowingLiteral( + static_cast(DMAHelper::base(&host_tensor)), xla_shape); + return Status::OK(); +} + +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal) { + std::vector buf_ptrs; + buf_ptrs.reserve(host_tensors.size()); + std::vector tensor_shapes(host_tensors.size()); + + for (int i = 0; i < host_tensors.size(); i++) { + // Validate runtime shapes and fail if it doesn't match the contract. + const Tensor* tensor = &host_tensors[i]; + buf_ptrs.emplace_back(static_cast(DMAHelper::base(tensor))); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(tensor->dtype(), tensor->shape(), + &tensor_shapes[i])); + } + + *literal = xla::BorrowingLiteral( + buf_ptrs, xla::ShapeUtil::MakeTupleShape(tensor_shapes)); + + return Status::OK(); +} + Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, Tensor* host_tensor) { TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index 220bec15538c36fa30abef9e729b64dbbb9f72b3..74685025c1780c5c0ba56205a98786582e9191e9 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { @@ -29,6 +30,17 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); +// Returns a BorrowingLiteral that utilizes the same underlying buffer owned by +// 'host_tensor'. +Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); + +// Returns a BorrowingLiteral tuple that utilizes the same underlying buffers +// owned by 'host_tensors'. +Status HostTensorsToBorrowingLiteralTuple( + tensorflow::gtl::ArraySlice host_tensors, + xla::BorrowingLiteral* literal); + // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . // Fails if the literal's primitive type != diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 3a08aa8cf4f5cea6210cc9470d57c3387445ea6e..ac768b206e2a8d163a4253432a1911152f89ce86 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -263,8 +263,7 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, // Compile the graph into an XLA computation. XlaCompiler::Options compiler_options; compiler_options.client = client; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - compiler_options.device_type = &device_type; + compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); compiler_options.flib_def = &graph->flib_def(); compiler_options.graph_def_version = graph->versions().producer(); compiler_options.allow_cpu_custom_calls = true; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index f7098917b191058c53a1d6a5923e80e5e8319d72..9c8e56a17e07348d3cfaaca0b5eb335295af05c3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -83,12 +83,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) : options_(options), initialization_status_(Status::OK()), next_step_id_(1), - device_( - new XlaCompilationDevice(SessionOptions(), *options_.device_type)), + device_(new XlaCompilationDevice(SessionOptions(), options_.device_type)), device_mgr_({device_}) { - // We no longer need the device_type. - options_.device_type = nullptr; - + CHECK(!options_.device_type.type_string().empty()); if (options_.populate_resource_manager) { initialization_status_ = (*options_.populate_resource_manager)(device_->resource_manager()); @@ -228,7 +225,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, bool is_entry_computation, - xla::Shape* xla_shape) { + xla::Shape* xla_shape) const { switch (arg.kind) { case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; @@ -655,10 +652,70 @@ Status XlaCompiler::CompileSingleOp( .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } + FixupSourceAndSinkEdges(graph.get()); return CompileGraph(options, name, std::move(graph), args, result); } +namespace { + +// Check that the ops of all non-functional nodes have been registered. +string ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { + std::vector invalid_ops; + for (const NodeDef& node : fdef->node_def()) { + const string& op = node.op(); + if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(op, &op_def).ok()) { + invalid_ops.push_back(op); + } + } + return tensorflow::str_util::Join(invalid_ops, ", "); +} + +// Check that the graph doesn't have any invalid nodes (e.g. incompatible with +// given device_type, invalid data type, missing attributes...) +Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { + std::vector invalid_ops; + for (const Node* node : graph->nodes()) { + if (node->type_string() == FunctionLibraryDefinition::kGradientOp) { + continue; + } + const FunctionDef* fdef = flib_def.Find(node->def().op()); + if (fdef) { + string error_msg = ValidateFunctionDef(fdef, flib_def); + if (!error_msg.empty()) { + invalid_ops.push_back( + strings::StrCat(node->def().op(), ":{", error_msg, "}")); + } + continue; + } + const OpDef* op_def; + if (!OpRegistry::Global()->LookUpOpDef(node->def().op(), &op_def).ok()) { + invalid_ops.push_back(node->def().op()); + continue; + } + TF_RETURN_IF_ERROR(ValidateNodeDef(node->def(), *op_def)); + if (!FindKernelDef(device_type, node->def(), nullptr, nullptr).ok()) { + invalid_ops.push_back(node->def().op()); + } + } + if (!invalid_ops.empty()) { + return errors::InvalidArgument(strings::StrCat( + "Detected unsupported operations when trying to compile graph ", name, + " on ", device_type.type_string(), ":", + tensorflow::str_util::Join(invalid_ops, ", "))); + } + return Status::OK(); +} + +} // namespace + Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -681,6 +738,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, FunctionalizeControlFlow(flib_runtime_->GetFunctionLibraryDefinition(), graph.get(), local_flib_def_.get())); + // Detect invalid nodes. + // FunctionalizeControlFlow may remove some nodes from the graph. + TF_RETURN_IF_ERROR(ValidateGraph(graph.get(), *options_.flib_def, + options_.device_type, name)); + xla::XlaBuilder builder(name); XlaContext* context = new XlaContext( this, &builder, options_.allow_cpu_custom_calls, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index bf496bd8bc81e67056eba380288bca88737cc00d..c93850ce270502ea1df1f6469963e96e86994fa2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -244,9 +245,9 @@ class XlaCompiler { typedef std::function ShapeRepresentationFn; struct Options { - // Name of the compilation device to use. Needs to be live only during - // XlaCompiler's constructor. - const DeviceType* device_type = nullptr; + // Name of the compilation device to use. It must be set by the caller. + // The default empty value is invalid. + DeviceType device_type = DeviceType(""); xla::Client* client = nullptr; @@ -313,7 +314,7 @@ class XlaCompiler { // See the class comment for more details about the argument passing // convention. Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, - xla::Shape* xla_shape); + xla::Shape* xla_shape) const; // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 55772ca324872f6d5fac008de7819b7fae64966a..613230452b74755ce7543ec2ab82861aa0dfeb7a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -45,8 +46,6 @@ namespace tensorflow { class XlaCompilerTest : public ::testing::Test { protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - void SetUp() override { client_ = xla::ClientLibrary::LocalClientOrDie(); @@ -58,7 +57,7 @@ class XlaCompilerTest : public ::testing::Test { XlaCompiler::Options DefaultOptions() { XlaCompiler::Options options; - options.device_type = &cpu_device_type_; + options.device_type = DeviceType(DEVICE_CPU_XLA_JIT); options.client = client_; options.flib_def = flib_def_.get(); return options; @@ -68,7 +67,6 @@ class XlaCompilerTest : public ::testing::Test { return compiler->local_flib_def_.get(); } - DeviceType cpu_device_type_; xla::Client* client_; std::unique_ptr flib_def_; }; @@ -979,5 +977,115 @@ TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +// Tests a graph which has a function with an invalid op. +TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + FunctionDef fn = FillFn(); + NodeDef* node = fn.add_node_def(); + node->set_name("Invalid"); + node->set_op("InvalidOp"); /* unsupported op */ + node = fn.add_node_def(); + node->set_name("Switch"); + node->set_op("Switch"); /* control flow node */ + *flib.add_function() = fn; + + TF_ASSERT_OK(flib_def_->AddFunctionDef(fn)); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + std::vector args; + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + str_util::StrContains(status.error_message(), "FillFn:{InvalidOp}")) + << status.error_message(); +} + +// Tests a graph which has a node with invalid data type. +TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef shape; + shape.set_name("Shape"); + shape.set_op("Shape"); + (*shape.mutable_attr())["T"].set_type(DT_INT32); + (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ + Status status; + Node* shape_node = graph->AddNode(shape, &status); + TF_ASSERT_OK(status); + graph->AddControlEdge(graph->source_node(), shape_node); + + std::vector args; + XlaCompiler::CompilationResult result; + XlaCompiler compiler(DefaultOptions()); + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type", + std::move(graph), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "is not in the list of allowed values")) + << status.error_message(); +} + +TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + NodeDef no_op; + no_op.set_name("NoOp"); + no_op.set_op("NoOp"); + Status status; + graph->AddNode(no_op, &status); + TF_ASSERT_OK(status); + + std::vector args; + XlaCompiler compiler(DefaultOptions()); + // No control edge linking NoOp with source/sink. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + XlaCompiler::CompilationResult result; + status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(str_util::StrContains(status.error_message(), + "The following nodes are unreachable " + "from the source in the graph: NoOp")) + << status.error_message(); + } + + // Fix control edges for NoOp. + { + std::unique_ptr graph_copy(new Graph(OpRegistry::Global())); + CopyGraph(*graph, graph_copy.get()); + EXPECT_TRUE(FixupSourceAndSinkEdges(graph_copy.get())); + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp", + std::move(graph_copy), args, &result)); + EXPECT_EQ(0, result.resource_updates.size()); + } +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f1594193af09c7193f03b4685d3a7d4510d654dd..a1da176fe30ddd0d4460a51b60b2568ecc1af6aa 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -19,11 +19,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -210,8 +212,9 @@ Status XlaHelpers::Iota(xla::XlaBuilder* builder, DataType dtype, int64 size, return errors::InvalidArgument("Invalid argument type ", DataTypeString(dtype)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); + *iota = builder->ConstantLiteral(linspace_literal); return Status::OK(); } @@ -245,8 +248,8 @@ Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis, return errors::InvalidArgument("Invalid argument type ", DataTypeString(index_type)); } - xla::Literal linspace_literal; - TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + xla::BorrowingLiteral linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToBorrowingLiteral(linspace, &linspace_literal)); // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index c6deb959a59f7b79500a0948b4035ea56cd9b4a1..1b8e516770c3e217dd7c2f26ce426895b478c2e4 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -53,7 +53,6 @@ xla_proto_library( deps = [ ":xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/compiler/xla/service:session_proto", ], ) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index aacb394ae5f92aa0d87ee3a23bcc3d4ec5cd99a3..8f08d3b2e04670ad6590aca1db0fd9d25faed83f 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -86,6 +86,7 @@ cc_library( hdrs = ["executable_build_options.h"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/core:lib", @@ -109,6 +110,7 @@ cc_library( "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:source_map_util", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index c9d275a77b5cd40225f4b5c45e02c242d27d9aa1..3d596a6e65430b6e9692aabd65fc8aa84b7b873d 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -64,7 +64,7 @@ StatusOr> Client::Transfer( } StatusOr> Client::TransferToServer( - const Literal& literal, const DeviceHandle* device_handle) { + const LiteralSlice& literal, const DeviceHandle* device_handle) { TransferToServerRequest request; *request.mutable_literal() = literal.ToProto(); if (device_handle) { @@ -91,7 +91,7 @@ StatusOr> Client::TransferToServer( return MakeUnique(stub_, response.data()); } -Status Client::TransferToInfeed(const Literal& literal, int64 replica_id, +Status Client::TransferToInfeed(const LiteralSlice& literal, int64 replica_id, const DeviceHandle* device_handle) { TransferToInfeedRequest request; *request.mutable_literal() = literal.ToProto(); diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index d57e2536d0b44cda46d7c1c2513b82c9f8a31c1b..68f0d0ac78c859fde7a6a007cd250b047a7bfcda 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -107,14 +107,14 @@ class Client { // device (and its replicas if replication is enabled). Otherwise, data is // transferred to the default device (and its replicas). StatusOr> TransferToServer( - const Literal& literal, const DeviceHandle* device_handle = nullptr); + const LiteralSlice& literal, const DeviceHandle* device_handle = nullptr); // Transfer the given literal to the Infeed interface of the device. // // device_handle and replica_id together specify a particular device; a device // assigned for the given replica_id among the replicas that the given device // handle belongs to. - Status TransferToInfeed(const Literal& literal, int64 replica_id = 0, + Status TransferToInfeed(const LiteralSlice& literal, int64 replica_id = 0, const DeviceHandle* device_handle = nullptr); // Transfers from the Outfeed of the device. @@ -153,8 +153,6 @@ class Client { // // If output_layout is non-null, then the output of the computation will be // stored using that layout. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> ComputeConstant( const XlaComputation& computation, const Layout* output_layout = nullptr) const; diff --git a/tensorflow/compiler/xla/client/compile_only_client.cc b/tensorflow/compiler/xla/client/compile_only_client.cc index dc69d2097ebe14ca0e14a39849d4fcae99024fdc..5c9abad4c3126be5e45e96c770c0679fe8606788 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.cc +++ b/tensorflow/compiler/xla/client/compile_only_client.cc @@ -24,7 +24,8 @@ namespace xla { StatusOr>> CompileOnlyClient::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector service_instances; service_instances.reserve(computations.size()); for (const AotXlaComputationInstance& instance : computations) { @@ -36,7 +37,8 @@ CompileOnlyClient::CompileAheadOfTime( service_instance.argument_layouts = instance.argument_layouts; service_instance.result_layout = instance.result_layout; } - return compiler_service_->CompileAheadOfTime(service_instances, options); + return compiler_service_->CompileAheadOfTime(service_instances, options, + metadata); } int64 CompileOnlyClient::PointerSizeForTriple(tensorflow::StringPiece triple) { diff --git a/tensorflow/compiler/xla/client/compile_only_client.h b/tensorflow/compiler/xla/client/compile_only_client.h index f9a7c31270c7a11175f47a537639a97d0c9211af..332c96503637344d56e363e19db4880c37ca9684 100644 --- a/tensorflow/compiler/xla/client/compile_only_client.h +++ b/tensorflow/compiler/xla/client/compile_only_client.h @@ -46,13 +46,15 @@ class CompileOnlyClient : public Client { const Shape* result_layout; }; - // Compiles a list of xla computations for ahead-of-time execution. This is - // intended for use in static compilation. The |options| parameter describes - // the target for which the compiler should emit code. + // Compiles a list of xla computations for ahead-of-time execution. + // This is intended for use in static compilation. The |options| + // parameter describes the target for which the compiler should emit + // code. |metadata|, if provided, is populated during compilation. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options); + const AotCompilationOptions& options, + std::unique_ptr* metadata = nullptr); // Returns the size of a pointer in bytes for a given triple. static int64 PointerSizeForTriple(tensorflow::StringPiece triple); diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc index 6e3c5cb484b8f1ef053fa287a4d462aeb886e530..7dee41f6a05025ec196b78e54015e8e71777031f 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.cc +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -87,6 +87,18 @@ ExecutableBuildOptions::dump_optimized_hlo_proto_to() const { return dump_optimized_hlo_proto_to_; } +ExecutableBuildOptions& +ExecutableBuildOptions::set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath) { + dump_unoptimized_hlo_proto_to_ = dirpath.ToString(); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::dump_unoptimized_hlo_proto_to() const { + return dump_unoptimized_hlo_proto_to_; +} + ExecutableBuildOptions& ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to( tensorflow::StringPiece dirpath) { dump_per_pass_hlo_proto_to_ = dirpath.ToString(); diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h index 11f10983606fe02b1edb11a260edde8e5f9a726f..9dc9be4423564fb967b247c2d1df31099cb80237 100644 --- a/tensorflow/compiler/xla/client/executable_build_options.h +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -64,6 +65,13 @@ class ExecutableBuildOptions { tensorflow::StringPiece dirpath); const tensorflow::gtl::optional& dump_optimized_hlo_proto_to() const; + // If set, specifies a dirpath to dump the start-of-optimization-pipeline HLO + // protobuf to (as in DebugOptions). + ExecutableBuildOptions& set_dump_unoptimized_hlo_proto_to( + tensorflow::StringPiece dirpath); + const tensorflow::gtl::optional& dump_unoptimized_hlo_proto_to() + const; + // If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs // to (as in DebugOptions). ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to( @@ -76,6 +84,13 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_hlo_profile(bool enabled); tensorflow::gtl::optional hlo_profile() const; + void add_disabled_hlo_pass(tensorflow::StringPiece pass_name) { + disabled_hlo_passes_.push_back(std::string(pass_name)); + } + const tensorflow::gtl::ArraySlice disabled_hlo_passes() const { + return disabled_hlo_passes_; + } + // Returns a string representation of the build options, suitable for // debugging. string ToString() const; @@ -87,8 +102,10 @@ class ExecutableBuildOptions { bool result_layout_set_ = false; tensorflow::gtl::optional generate_hlo_graph_; tensorflow::gtl::optional dump_optimized_hlo_proto_to_; + tensorflow::gtl::optional dump_unoptimized_hlo_proto_to_; tensorflow::gtl::optional dump_per_pass_hlo_proto_to_; DeviceMemoryAllocator* device_allocator_ = nullptr; + std::vector disabled_hlo_passes_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index a7c55c6b2b7fe2b5541ce71bf3eaa24114522fc5..ae0308020d014e038d2f0fd7de6c5f372d6cbed1 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -185,7 +185,7 @@ StatusOr LocalExecutable::Run( run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); - if (executable_->dumping()) { + if (executable_->dumping_snapshot()) { return ExecuteAndDump(&service_options, arguments); } return executable_->ExecuteOnStreamWrapper( @@ -195,36 +195,36 @@ StatusOr LocalExecutable::Run( StatusOr LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { - executable_->session_module()->set_execution_platform( + executable_->hlo_snapshot()->set_execution_platform( backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module())); + TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->hlo_snapshot())); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer result, executable_->ExecuteOnStream(run_options, arguments, /*hlo_execution_profile=*/nullptr)); - TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module())); - TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); + TF_RETURN_IF_ERROR(RecordResult(&result, executable_->hlo_snapshot())); + TF_RETURN_IF_ERROR(executable_->DumpHloSnapshot()); return std::move(result); } Status LocalExecutable::RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module) { - session_module->clear_arguments(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_arguments(); for (const ShapedBuffer* argument : arguments) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*argument)); - *session_module->add_arguments() = literal->ToProto(); + *hlo_snapshot->add_arguments() = literal->ToProto(); } return Status::OK(); } Status LocalExecutable::RecordResult(const ShapedBuffer* result, - SessionModule* session_module) { - session_module->clear_result(); + HloSnapshot* hlo_snapshot) { + hlo_snapshot->clear_result(); TF_ASSIGN_OR_RETURN(std::unique_ptr literal, LiteralFromShapedBuffer(*result)); - *session_module->mutable_result() = literal->ToProto(); + *hlo_snapshot->mutable_result() = literal->ToProto(); return Status::OK(); } @@ -304,6 +304,11 @@ StatusOr> LocalClient::ShapedBufferToLiteral( shaped_buffer); } +StatusOr LocalClient::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + return local_service_->GlobalDataToShapedBuffer(data, replica_number); +} + Status LocalClient::TransferToInfeedLocal(const Literal& literal, int device_ordinal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index d63d4ec7f3744d507cc854213e430e25e861e559..4d9e0d7cd9d6ddebead1e12b23e94b529038039b 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/local_service.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" @@ -58,12 +59,18 @@ class LocalExecutable { // Validates that the given arguments and options satisfy various constraints // of the computation. + // + // The given ExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). Status ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& run_options, const Backend& backend); // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. + // + // The given ServiceExecutableRunOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); @@ -72,11 +79,10 @@ class LocalExecutable { // proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, - SessionModule* session_module); + HloSnapshot* hlo_snapshot); // Records the result of the computation in a SessionModule proto. - Status RecordResult(const ShapedBuffer* result, - SessionModule* session_module); + Status RecordResult(const ShapedBuffer* result, HloSnapshot* hlo_snapshot); // Returns a literal containing the contents of the given ShapedBuffer. StatusOr> LiteralFromShapedBuffer( @@ -109,6 +115,9 @@ class LocalClient : public Client { // Build and return a LocalExecutable object. The executable is compiled using // the given XlaComputation, argument layouts and options. + // + // The given ExecutableBuildOptions override any values from legacy_flags + // (TF_XLA_FLAGS environment variable). StatusOr> Compile( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -127,6 +136,11 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + // Transfer the given literal to the infeed queue of the given device. // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD index 0d6e207971ec64515ec5e6da292910920edd101a..507a2dc5f088e159156f0ef3d663ba2819f6a2d4 100644 --- a/tensorflow/compiler/xla/client/xla_client/BUILD +++ b/tensorflow/compiler/xla/client/xla_client/BUILD @@ -37,7 +37,6 @@ cc_library( ], ) -# TODO(b/74197823): Replace computation_builder with xla_builder. cc_library( name = "xla_builder", srcs = ["xla_builder.cc"], diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index ae506317c2e4862d77cb4f0628e919871ad1aeb2..ae8fbdb2dc8b3f4fc21bcfef9692645f7e1d48b5 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1611,14 +1611,41 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, }); } -XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids) { return NoteErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); + const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); + auto b = CreateSubBuilder("sum"); + b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), + b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); + TF_ASSIGN_OR_RETURN(auto computation, b->Build()); + return CrossReplicaSum(operand, computation, replica_group_ids, + /*channel_id=*/tensorflow::gtl::nullopt); + }); +} +XlaOp XlaBuilder::CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids, + const tensorflow::gtl::optional& channel_id) { + return NoteErrorOrReturn([&]() -> StatusOr { + if (channel_id.has_value()) { + return Unimplemented( + "replica_group_ids and channel_id and is not supported in AllReduce"); + } + + HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + for (int64 replica_group_id : replica_group_ids) { + instr.add_replica_group_ids(replica_group_id); + } + + AddCalledComputation(computation, &instr); return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, {operand}); diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h index 2b3013a91c488782098bd81994e899eae5a1f506..0329e42ed1aef8edd1537e888ddcd78f08584407 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h @@ -528,9 +528,35 @@ class XlaBuilder { tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding); - // Returns the sum of the operand value across all replicas. All replicas - // supply one input to the sum and all replicas receive the resulting sum. - XlaOp CrossReplicaSum(const XlaOp& operand); + // Returns the sum of the operand value within each subgroup of replicas. All + // replicas supply one input to the sum and all replicas receive the resulting + // sum for each subgroup. + XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice replica_group_ids = {}); + + // Enqueues an operation that do an AllReduce of the operand cross cores. Here + // AllReduce means doing a reduction on the input operand cross cores and then + // broadcasting the reduction result to those cores. The reduction function is + // defined by `computation`, which should be a commutative computation on + // scalars, e.g., add, min, or max. The way that AllReduce is applied is + // configured by: + // + // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // - `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce when it's ready to use. + XlaOp CrossReplicaSum( + const XlaOp& operand, const XlaComputation& computation, + tensorflow::gtl::ArraySlice replica_group_ids = {}, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); // Enqueues an operation that scatters the `source` array to the selected // indices of each window. diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index a76fdcda250168cbed2acd01bdd9ddc3b4c93b92..e8f29b83291a7cb238dc25b9f4bb743fe426a162 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -65,6 +65,16 @@ void SetDefaultLayoutToContainer( return layout; } +/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor) { + Layout layout; + layout.set_format(DENSE); + for (int i = major_to_minor.size() - 1; i >= 0; i--) { + layout.add_minor_to_major(major_to_minor[i]); + } + return layout; +} + /* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { Layout layout; layout.set_format(SPARSE); @@ -88,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } // namespace /* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) { + if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) { + // Opaque and token types have empty layouts. + return Layout(); + } + // A Layout proto corresponds to a single array, not a tuple. - DCHECK(!ShapeUtil::IsTuple(shape)); + CHECK(ShapeUtil::IsArray(shape)); return CreateDefaultLayoutForRank(shape.dimensions_size()); } @@ -116,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) { SetToDefaultLayout(&element_shape); } shape->clear_layout(); - } else if (ShapeUtil::IsOpaque(*shape)) { - shape->clear_layout(); - } else { + } else if (ShapeUtil::IsArray(*shape)) { shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); SetDefaultLayoutToContainer(minor_to_major); + } else { + // Opaque, token types etc. have no layout. + shape->clear_layout(); } } @@ -150,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } return Status::OK(); - } else if (ShapeUtil::IsOpaque(shape)) { - if (shape.has_layout()) { - return InvalidArgument("opaque should not have a layout field"); - } - return Status::OK(); - } else { - // Array shape. + } else if (ShapeUtil::IsArray(shape)) { if (!shape.has_layout()) { return InvalidArgument("shape %s does not have a layout", ShapeUtil::HumanString(shape).c_str()); } return ValidateLayoutForShape(shape.layout(), shape); + } else { + // Token, opaque, etc. shape. + if (shape.has_layout()) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); + } + return Status::OK(); } } @@ -171,8 +189,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (ShapeUtil::IsOpaque(shape)) { - return Status::OK(); + if (!ShapeUtil::IsArray(shape)) { + return InvalidArgument( + "shape of primitive type %s should not have a layout", + PrimitiveType_Name(shape.element_type()).c_str()); } if (layout.format() == INVALID_FORMAT) { @@ -263,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ bool LayoutUtil::IsPadded(const Shape& shape) { - if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) || + if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) || shape.layout().padded_dimensions_size() == 0) { return false; } @@ -313,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); - } else if (ShapeUtil::IsOpaque(shape)) { + } else if (!ShapeUtil::IsArray(shape)) { + // Opaque, token types etc. ignore layout. return true; } return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; @@ -422,12 +443,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { /* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs) { - if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) { - return false; - } if (ShapeUtil::IsTuple(lhs)) { - if (ShapeUtil::TupleElementCount(lhs) != - ShapeUtil::TupleElementCount(rhs)) { + if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) != + ShapeUtil::TupleElementCount(rhs)) { return false; } for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) { @@ -436,9 +454,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { } } return true; - } else { + } else if (ShapeUtil::IsArray(lhs)) { return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && LayoutUtil::Equal(lhs.layout(), rhs.layout()); + } else { + // Layouts of non-array and non-tuple shapes is ignored. + return true; } } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index d3d6a2cc94012f7113fd1cb1b17e9c9d5323d9bf..739bbe73675c7fb855627006028eafdf703d6540 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Similar to MakeLayout, but take indices in reverse order. + static Layout MakeLayoutFromMajorToMinor( + tensorflow::gtl::ArraySlice major_to_minor); + // Creates a sparse layout with the given maximum number of elements. (This is // a convenience function for protobuf construction.) static Layout MakeSparseLayout(int64 max_sparse_elements); diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 4fd1d818e3e3b417eee9f6b14bb598bfb9480c6e..e4c825450dcd45a8fbeaacbb2ad145f94307176f 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) { "elements, but shape is rank")); } +TEST_F(LayoutUtilTest, CopyTokenLayout) { + Shape src = ShapeUtil::MakeTokenShape(); + Shape dst = ShapeUtil::MakeTokenShape(); + + // Layouts are trivially the same for token types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyOpaqueLayout) { + Shape src = ShapeUtil::MakeOpaqueShape(); + Shape dst = ShapeUtil::MakeOpaqueShape(); + + // Layouts are trivially the same for opaque types and copying layouts should + // be a nop. + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + +TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, ClearLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), @@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) { EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout()); } +TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) { + // Opaque and token types trivially have layouts. + for (Shape shape : + {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) { + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + LayoutUtil::ClearLayout(&shape); + EXPECT_TRUE(LayoutUtil::HasLayout(shape)); + } +} + TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) { Shape shape = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}), diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc index 3696fdbe12e311af3b286ef0dfe91377983b72dd..bf9679cafec72c2e9dc5796e9058c6703239c508 100644 --- a/tensorflow/compiler/xla/literal_comparison.cc +++ b/tensorflow/compiler/xla/literal_comparison.cc @@ -716,9 +716,11 @@ Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) { } return AppendStatus(result, - tensorflow::strings::Printf("expected: %s\nactual: %s", - expected.ToString().c_str(), - actual.ToString().c_str())); + tensorflow::strings::Printf( + "\nat index: %s\nexpected: %s\nactual: %s", + Literal::MultiIndexAsString(multi_index).c_str(), + ToStringTruncated(expected).c_str(), + ToStringTruncated(actual).c_str())); } Status Near(const LiteralSlice& expected, const LiteralSlice& actual, diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 4c560767dc603bf805f365d594810f4df7e90ed3..6b295897004cebce003ddd3999aacf63915ffe5f 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -807,6 +807,47 @@ std::unique_ptr LiteralBase::Relayout( return result; } +StatusOr> LiteralBase::Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const { + if (!ShapeUtil::IsArray(shape())) { + return InvalidArgument("Broadcast only supports arrays."); + } + + for (int64 i = 0; i < dimensions.size(); i++) { + TF_RET_CHECK(shape().dimensions(i) == + result_shape.dimensions(dimensions[i])); + } + + std::unique_ptr result = MakeUnique(result_shape); + + // scratch_source_index is temporary storage space for the computed index into + // the input literal. We put it here to avoid allocating an std::vector in + // every iteration of ShapeUtil::ForEachIndex. + std::vector scratch_source_index(shape().dimensions_size()); + + char* dest_data = static_cast(result->untyped_data()); + const char* source_data = static_cast(untyped_data()); + const int64 primitive_size = + ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type()); + + ShapeUtil::ForEachIndex( + result_shape, [&](tensorflow::gtl::ArraySlice output_index) { + for (int64 i = 0; i < dimensions.size(); ++i) { + scratch_source_index[i] = output_index[dimensions[i]]; + } + int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex( + result_shape, output_index); + int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex( + shape(), scratch_source_index); + memcpy(dest_data + primitive_size * dest_index, + source_data + primitive_size * source_index, primitive_size); + return true; + }); + + return std::move(result); +} + StatusOr> LiteralBase::Reshape( tensorflow::gtl::ArraySlice dimensions) const { if (!ShapeUtil::IsArray(shape())) { @@ -946,6 +987,23 @@ std::unique_ptr LiteralBase::Transpose( return new_literal; } +template +std::unique_ptr LiteralBase::SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const { + auto result_literal = MakeUnique(result_shape); + DimensionVector new_indices(ShapeUtil::Rank(result_shape)); + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT /*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + NativeT value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; +} + std::unique_ptr LiteralBase::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { @@ -963,51 +1021,17 @@ std::unique_ptr LiteralBase::Slice( const auto result_shape = ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, LayoutUtil::MinorToMajor(shape())); - - auto result_literal = MakeUnique(result_shape); - - DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { case F32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, float /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - float value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); + case BF16: + return SliceInternal(result_shape, start_indices); case C64: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, complex64 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - complex64 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case S32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - int32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); case U32: - result_literal->EachCell( - [&](tensorflow::gtl::ArraySlice indices, uint32 /*value*/) { - for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { - new_indices[i] = indices[i] + start_indices[i]; - } - uint32 value = Get(new_indices); - result_literal->Set(indices, value); - }); - return result_literal; + return SliceInternal(result_shape, start_indices); default: LOG(FATAL) << "not yet implemented: " << PrimitiveType_Name(result_shape.element_type()); @@ -2317,28 +2341,28 @@ LiteralSlice::LiteralSlice(const LiteralBase& literal, : LiteralBase(), root_piece_(&literal.piece(view_root)) {} BorrowingLiteral::BorrowingLiteral(const char* src_buf_ptr, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsArray(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsArray(*shape_)); CHECK_NE(src_buf_ptr, nullptr); - CHECK(LayoutUtil::HasLayout(shape_)); + CHECK(LayoutUtil::HasLayout(*shape_)); root_piece_ = Piece(); root_piece_.set_buffer(const_cast(src_buf_ptr)); - root_piece_.set_subshape(&shape_); + root_piece_.set_subshape(shape_.get()); } BorrowingLiteral::BorrowingLiteral( tensorflow::gtl::ArraySlice src_buf_ptrs, const Shape& shape) - : LiteralBase(), shape_(shape) { - CHECK(ShapeUtil::IsTuple(shape_)); - CHECK(!ShapeUtil::IsNestedTuple(shape_)); - CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(shape_)); + : LiteralBase(), shape_(MakeUnique(shape)) { + CHECK(ShapeUtil::IsTuple(*shape_)); + CHECK(!ShapeUtil::IsNestedTuple(*shape_)); + CHECK_EQ(src_buf_ptrs.size(), ShapeUtil::TupleElementCount(*shape_)); root_piece_ = Piece(); - root_piece_.set_subshape(&shape_); - BuildPieceSubtree(shape_, &root_piece_); + root_piece_.set_subshape(shape_.get()); + BuildPieceSubtree(*shape_, &root_piece_); for (int i = 0; i < src_buf_ptrs.size(); ++i) { - const auto& src_shape = shape_.tuple_shapes(i); + const auto& src_shape = shape_->tuple_shapes(i); CHECK(ShapeUtil::IsArray(src_shape)); root_piece_.child(i).set_buffer(const_cast(src_buf_ptrs[i])); } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 609dc7a3aca646a5bb787487de101ac115df8ea5..8e4159e360e042beb31a75c432a3c7dfa7356007 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -277,6 +277,12 @@ class LiteralBase { StatusOr> Reshape( tensorflow::gtl::ArraySlice dimensions) const; + // Creates a new literal by broadcasting this literal with `dimensions` to + // yield a literal of shape `result_shape`. + StatusOr> Broadcast( + const Shape& result_shape, + tensorflow::gtl::ArraySlice dimensions) const; + // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers // in the original literal, and it specifies the order of the new dimensions @@ -536,6 +542,12 @@ class LiteralBase { friend class Literal; friend class LiteralSlice; friend class BorrowingLiteral; + + private: + template + std::unique_ptr SliceInternal( + const Shape& result_shape, + tensorflow::gtl::ArraySlice start_indices) const; }; // Class representing literal values in XLA. @@ -1087,8 +1099,10 @@ class BorrowingLiteral : public LiteralBase { const Piece& root_piece() const override { return root_piece_; }; Piece root_piece_; - // Shape of this literal. - const Shape shape_; + // Shape of this literal. Stored as unique_ptr so such that the (default) + // move construction of this class would be trivially correct: the pointer to + // Shape root_piece_ stores will still point to the correct address. + std::unique_ptr shape_; }; template diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 77f979a0d701f09162e112b69f6128008872aa18..53b926163c472c3ed7b72bf8b035d13996d59e34 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -1431,7 +1431,7 @@ TEST_F(LiteralUtilTest, LiteralSliceOfALiteralSlice) { EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtr) { std::vector int64_values = {1, 2, 3}; const Shape literal_shape = ShapeUtil::MakeShape(S64, {3}); @@ -1443,7 +1443,7 @@ TEST_F(LiteralUtilTest, BorrowingLiteralFromOneBufferPtrTest) { EXPECT_EQ(literal.Get({2}), 3); } -TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrsTest) { +TEST_F(LiteralUtilTest, BorrowingLiteralFromMultipleBufferPtrs) { std::vector one_two_three = {1, 2, 3}; const Shape one_two_three_shape = ShapeUtil::MakeShape(S64, {3}); @@ -1810,5 +1810,35 @@ TEST_F(LiteralUtilTest, GetSparseElementAsString) { tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{0})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 1}, {2, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) { + std::unique_ptr literal = Literal::CreateR1({1, 2}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}), + /*dimensions=*/{1})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{1, 2}, {1, 2}})); +} + +TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) { + std::unique_ptr literal = Literal::CreateR0(9); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr broadcasted_literal, + literal->Broadcast( + /*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}), + /*dimensions=*/{})); + EXPECT_EQ(*broadcasted_literal, *Literal::CreateR2({{9, 9}, {9, 9}})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 932cce943f7c046a85984e6e5ed6b59dae371473..83834c1ff65ea2f9989fe08279c29056d9070adb 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -12,6 +12,7 @@ py_library( deps = [ ":pywrap_xla", "//tensorflow/compiler/xla:xla_data_proto_py", + "//tensorflow/compiler/xla/service:hlo_proto_py", ], ) @@ -53,6 +54,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index cb4dc1782b680fca1485e883343fbb262b86b1d1..445cee1aa7b462f7ae2b6b0771ff57f0c8f3db99 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" namespace xla { - namespace swig { // TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of @@ -97,6 +96,36 @@ const ScopedShapedBuffer* LocalShapedBuffer::shaped_buffer() const { return &shaped_buffer_; } +ShapedBuffer LocalShapedBuffer::Release() { return shaped_buffer_.release(); } + +LocalShapedBufferTuple::LocalShapedBufferTuple( + std::vector elements) + : elements_(std::move(elements)) { + for (auto* element : elements_) { + DCHECK(element != nullptr); + } +} + +LocalShapedBufferTuple::~LocalShapedBufferTuple() { + for (LocalShapedBuffer* element : elements_) { + if (element != nullptr) { + delete element; + } + } +} + +StatusOr LocalShapedBufferTuple::Release(int i) { + LocalShapedBuffer* element = elements_[i]; + if (element == nullptr) { + return InvalidArgument("Attempted to release already-released element %d.", + i); + } + elements_[i] = nullptr; + return element; +} + +int LocalShapedBufferTuple::size() const { return elements_.size(); } + static StatusOr ToBuffer(LocalClient* client, int device_ordinal, const Literal& arg) { @@ -276,6 +305,15 @@ const XlaComputation& LocalComputation::computation() const { return computation_; } +string LocalComputation::GetSerializedProto() const { + string result; + if (!computation_.proto().SerializeToString(&result)) { + LOG(ERROR) << "Failed to serialize the HloModuleProto."; + return ""; + } + return result; +} + StatusOr LocalComputation::GetReturnValueShape() const { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation_.GetProgramShape()); @@ -589,10 +627,12 @@ _FORWARD_BINOP(Or) _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) +_FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) +_FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -622,6 +662,54 @@ void DeleteLocalComputation(LocalComputation* computation) { delete computation; } -} // namespace swig +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer) { + if (!ShapeUtil::IsTuple( + local_shaped_buffer->shaped_buffer()->on_device_shape())) { + return InvalidArgument( + "Attemped to destructure a LocalShapedBuffer that did not have a tuple " + "shape; shape: %s", + ShapeUtil::HumanString( + local_shaped_buffer->shaped_buffer()->on_device_shape()) + .c_str()); + } + DeviceMemoryAllocator* allocator = + local_shaped_buffer->shaped_buffer()->memory_allocator(); + ShapedBuffer tuple_buffer = local_shaped_buffer->Release(); + + // Extract some metadata we use to construct scoped buffers. + const se::Platform* platform = tuple_buffer.platform(); + int device_ordinal = tuple_buffer.device_ordinal(); + + ShapeTree& shape_tree = tuple_buffer.buffers(); + const Shape& tuple_shape = tuple_buffer.on_device_shape(); + std::vector results; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(tuple_shape); ++i) { + // Create a shaped buffer for this destructured tuple element. + const Shape& subshape = ShapeUtil::GetSubshape(tuple_shape, {i}); + VLOG(3) << "Starting tuple element " << i << " subshape: " << subshape; + ShapedBuffer shaped_buffer(subshape, subshape, platform, device_ordinal); + + ShapeUtil::ForEachSubshape( + subshape, [&](const Shape& s, const ShapeIndex& index) { + ShapeIndex original(index); + original.push_front(i); + se::DeviceMemoryBase* device_memory = + shape_tree.mutable_element(original); + shaped_buffer.set_buffer(*device_memory, index); + *device_memory = se::DeviceMemoryBase(); + }); + + VLOG(3) << "Completed tuple element: " << i; + results.push_back(new LocalShapedBuffer( + ScopedShapedBuffer(std::move(shaped_buffer), allocator))); + } + // Deallocate the root buffer. + se::DeviceMemoryBase root_buffer = tuple_buffer.root_buffer(); + TF_RETURN_IF_ERROR(allocator->Deallocate(device_ordinal, root_buffer)); + return new LocalShapedBufferTuple(std::move(results)); +} + +} // namespace swig } // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index a06b85b4ea28c4f386598901138930eaaed12079..0da3964676e9c6729229686f38bb05c8b2427bff 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { - namespace swig { // Initializes the number of replicas that XLA will be initialized with (when @@ -69,10 +68,42 @@ class LocalShapedBuffer { StatusOr > ToLiteral() const; + // Transfers ownership of the encapsulated ShapedBuffer to the caller, + // analogous to std::unique_ptr::release(). + ShapedBuffer Release(); + private: ScopedShapedBuffer shaped_buffer_; }; +// Result of a tuple destructuring operation on a LocalShapedBuffer -- this +// appears to be a simpler mechanism for the time being than an alternative like +// using SWIG to transform std::vectors into Python lists of SWIG objects +// directly. +class LocalShapedBufferTuple { + public: + // Note: any LocalShapedBuffer elements that are not Release()'d will be + // deallocated in the destructor. + explicit LocalShapedBufferTuple(std::vector elements); + + ~LocalShapedBufferTuple(); + + // Releases the ith element to the caller. Further attempts to release the ith + // element will return an invalid argument error. + StatusOr Release(int i); + + // Returns the number of elements in the destructured tuple. + int size() const; + + private: + std::vector elements_; +}; + +// Destructures a tuple-valued LocalShapedBuffer into its constitutent elements +// in LocalShapedBufferTuple form. +StatusOr DestructureLocalShapedBufferTuple( + LocalShapedBuffer* local_shaped_buffer); + // Wraps a LocalExecutable produced by compiling a // LocalComputation. The Execute method forwards to that of the // underlying LocalExecutable, and additionally handles tranferring @@ -112,6 +143,11 @@ class LocalComputation { const XlaComputation& computation() const; + // Returns the HloModuleProto contained in the XlaComputation in the + // serialized binary format. Logs an internal error and returns an empty + // string on failure. + string GetSerializedProto() const; + // Returns the return-value shape for this computation. StatusOr GetReturnValueShape() const; @@ -300,10 +336,12 @@ class LocalComputationBuilder { _FORWARD_UNOP(Not) _FORWARD_UNOP(Abs) _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Expm1) _FORWARD_UNOP(Floor) _FORWARD_UNOP(Ceil) _FORWARD_UNOP(Round) _FORWARD_UNOP(Log) + _FORWARD_UNOP(Log1p) _FORWARD_UNOP(Sign) _FORWARD_UNOP(Cos) _FORWARD_UNOP(Sin) @@ -331,7 +369,6 @@ void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); void DeleteLocalComputation(LocalComputation* computation); } // namespace swig - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 04c56bbba95fbf3248df6c49700ff563c8b253c0..477df6fde25d0db760e08df9d335bd12e31ccb55 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -200,6 +200,20 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::LocalShapedBufferTuple*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + SWIG_fail; + } +} + + %typemap(out) StatusOr< std::unique_ptr > { if ($1.ok()) { std::unique_ptr value = $1.ConsumeValueOrDie(); @@ -851,6 +865,11 @@ tensorflow::ImportNumpy(); })) { return nullptr; } + if (!HandleStringAttribute($input, "dump_unoptimized_hlo_proto_to", [&](string s) { + build_options.set_dump_unoptimized_hlo_proto_to(std::move(s)); + })) { + return nullptr; + } if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) { build_options.set_dump_per_pass_hlo_proto_to(std::move(s)); })) { @@ -900,12 +919,16 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalShapedBuffer; %unignore xla::swig::LocalShapedBuffer::FromLiteral; %unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::LocalShapedBufferTuple; +%unignore xla::swig::LocalShapedBufferTuple::Release; +%unignore xla::swig::LocalShapedBufferTuple::size; %unignore xla::swig::CompiledLocalComputation; %unignore xla::swig::CompiledLocalComputation::Execute; %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; %unignore xla::swig::LocalComputation::GetReturnValueShape; +%unignore xla::swig::LocalComputation::GetSerializedProto; %unignore xla::swig::LocalOp; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; @@ -968,10 +991,12 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Not; %unignore xla::swig::LocalComputationBuilder::Abs; %unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Expm1; %unignore xla::swig::LocalComputationBuilder::Floor; %unignore xla::swig::LocalComputationBuilder::Ceil; %unignore xla::swig::LocalComputationBuilder::Round; %unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Log1p; %unignore xla::swig::LocalComputationBuilder::Sign; %unignore xla::swig::LocalComputationBuilder::Cos; %unignore xla::swig::LocalComputationBuilder::Sin; @@ -983,6 +1008,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ReciprocalF32; %unignore xla::swig::LocalComputationBuilder::Neg; %unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DestructureLocalShapedBufferTuple; %unignore xla::swig::DeleteLocalShapedBuffer; %unignore xla::swig::DeleteLocalComputation; %unignore xla::swig::DeleteCompiledLocalComputation; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index 1d5b75d1bee2dcee3e448d0bcb72103b539efac6..c025127c3cf1871d4def1297ed36c046cae61d4b 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -28,6 +28,7 @@ import numpy as np from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api +from tensorflow.compiler.xla.service import hlo_pb2 # Most functions are snake_case for consistency with other modules, whereas @@ -88,10 +89,12 @@ _UNARY_OPS = [ 'Not', 'Abs', 'Exp', + 'Expm1', 'Floor', 'Round', 'Ceil', 'Log', + 'Log1p', 'Sign', 'Cos', 'Sin', @@ -183,6 +186,14 @@ class LocalBuffer(object): self._delete(self.c_local_shaped_buffer) self.c_local_shaped_buffer = None + def destructure(self): + assert self.c_local_shaped_buffer is not None + result = c_api.DestructureLocalShapedBufferTuple(self.c_local_shaped_buffer) + self.c_local_shaped_buffer = None + size = result.size() + destructured = tuple(LocalBuffer(result.Release(i)) for i in xrange(size)) + return destructured + def is_deleted(self): return self.c_local_shaped_buffer is None @@ -352,6 +363,7 @@ class CompileOptions(object): def __init__(self): self.generate_hlo_graph = None self.dump_optimized_hlo_proto_to = None + self.dump_unoptimized_hlo_proto_to = None self.dump_per_pass_hlo_proto_to = None self.hlo_profile = False @@ -410,6 +422,17 @@ class LocalComputation(object): assert isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation + def GetProto(self): + """Get the HloModuleProto proto object in this local computation. + + Returns: + An HloModuleProto proto object that has the whole-graph information. + """ + + serialized = self.c_local_computation.GetSerializedProto() + proto = hlo_pb2.HloModuleProto.FromString(serialized) + return proto + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): """Compiles an un-compiled local computation. @@ -1100,6 +1123,61 @@ class ComputationBuilder(object): dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) return dimension_numbers + def ConvGeneralDilated(self, lhs, rhs, window_strides, padding, lhs_dilation, + rhs_dilation, dimension_numbers): + """Enqueues a ConvGeneralDilated operation onto the computation. + + Args: + lhs: LocalOp for the rank N+2 array of inputs. + rhs: LocalOp for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of integer dilation factors. + rhs_dilation: length-N array-like of integer dilation factors. + dimension_numbers: either an xla_data_pb2.ConvolutionDimensionNumbers or a + triple (lhs_spec, rhs_spec, out_spec) where each element is a string of + length N+2 identifying by position (1) batch dimensions in lhs, rhs, and + the output with the character 'N', (2) feature dimensions in lhs and the + output with the character 'C', (3) input and output feature dimensions + in rhs with the characters 'I' and 'O' respectively, and (4) spatial + dimension correspondences between lhs, rhs, and the output using any + distinct characters. For example, to indicate dimension numbers + consistent with the Conv operation with two spatial dimensions, one + could use ('NCHW', 'OIHW', 'NCHW'). As another example, to indicate + dimension numbers consistent with the TensorFlow Conv2D operation, one + could use ('NHWC', 'HWIO', 'NHWC'). When using the latter form of + convolution dimension specification, window strides are associated with + spatial dimension character labels according to the order in which the + labels appear in the rhs_spec string, so that window_strides[0] is + matched with the dimension corresponding to the first character + appearing in rhs_spec that is not 'I' or 'O'. + + Returns: a LocalOp representing the ConvGenralDilated operation. + """ + if not isinstance(dimension_numbers, + xla_data_pb2.ConvolutionDimensionNumbers): + lhs_spec, rhs_spec, out_spec = dimension_numbers + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + + dimension_numbers.input_batch_dimension = lhs_spec.index('N') + dimension_numbers.input_feature_dimension = lhs_spec.index('C') + dimension_numbers.output_batch_dimension = out_spec.index('N') + dimension_numbers.output_feature_dimension = out_spec.index('C') + dimension_numbers.kernel_output_feature_dimension = rhs_spec.index('O') + dimension_numbers.kernel_input_feature_dimension = rhs_spec.index('I') + + dimension_numbers.kernel_spatial_dimensions.extend( + i for i, c in enumerate(rhs_spec) if c not in {'I', 'O'}) + dimension_numbers.input_spatial_dimensions.extend( + sorted((i for i, c in enumerate(lhs_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(lhs_spec[i]))) + dimension_numbers.output_spatial_dimensions.extend( + sorted((i for i, c in enumerate(out_spec) if c not in {'N', 'C'}), + key=lambda i: rhs_spec.index(out_spec[i]))) + return self._client.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers) + def _forward_methods_to_local_builder(): """Forward remaining ComputationBuilder methods to the C API. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index c073c02040e4d260cf760ea2b25f70d60ddd41a1..71e1d60a4e23dbfef333223c396e109533da9365 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -164,6 +164,16 @@ class ComputationsWithConstantsTest(LocalComputationTest): c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + def testGetProto(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + built = c.Build() + proto = built.GetProto() # HloModuleProto + self.assertTrue(len(proto.computations) == 1) + self.assertTrue(len(proto.computations[0].instructions) == 3) + def testSum2DF64(self): c = self._NewComputation() c.Add( @@ -355,6 +365,55 @@ class LocalBufferTest(LocalComputationTest): with self.assertRaises(ValueError): compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + def testDestructureTupleEmpty(self): + t = () + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 0) + + def testDestructureTupleOneArrayElement(self): + t = (np.array([1, 2, 3, 4], dtype=np.int32),) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 1) + array = pieces[0] + got = array.to_py() + want = NumpyArrayS32([1, 2, 3, 4]) + np.testing.assert_equal(want, got) + + def testDestructureTupleTwoArrayElementDifferentType(self): + t = (np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32), + np.array([2, 3, 4, 5], dtype=np.int32)) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + array0, array1 = pieces + got = array0.to_py() + want = NumpyArrayF32([1.0, 2.0, 3.0, 4.0]) + np.testing.assert_equal(want, got) + got = array1.to_py() + want = NumpyArrayS32([2, 3, 4, 5]) + np.testing.assert_equal(want, got) + + def testDestructureTupleNested(self): + t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5])) + local_buffer = xla_client.LocalBuffer.from_pyval(t) + pieces = local_buffer.destructure() + self.assertTrue(local_buffer.is_deleted()) + self.assertEqual(len(pieces), 2) + tuple0, array1 = pieces + got = array1.to_py() + want = NumpyArrayS32([5]) + np.testing.assert_equal(want, got) + got = tuple0.to_py() + self.assertEqual(type(got), tuple) + self.assertEqual(len(got), 2) + np.testing.assert_equal(NumpyArrayF32([1.0, 2.0]), got[0]) + np.testing.assert_equal(NumpyArrayS32([3, 4]), got[1]) + class SingleOpTest(LocalComputationTest): """Tests for single ops. @@ -509,6 +568,46 @@ class SingleOpTest(LocalComputationTest): [40., 50., 0.]]]]) self._ExecuteAndCompareClose(c, expected=result) + def testConvGeneralDilatedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + dimension_numbers = ("NCHW", "OIHW", "NCHW") + c.ConvGeneralDilated(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvGeneralDilatedPermutedF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + + dimension_numbers = ("NHWC", "OIHW", "CWNH") + c.ConvGeneralDilated(c.Constant(np.transpose(lhs, (0, 2, 3, 1))), + c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation, + dimension_numbers) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=np.transpose(result, (1, 3, 0, 2))) + def testBooleanNot(self): c = self._NewComputation() arr = NumpyArrayBool([True, False, True]) @@ -521,6 +620,12 @@ class SingleOpTest(LocalComputationTest): c.Exp(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + def testExpm1(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Expm1(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.expm1(arr)) + def testRound(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) @@ -533,6 +638,12 @@ class SingleOpTest(LocalComputationTest): c.Log(c.Constant(arr)) self._ExecuteAndCompareClose(c, expected=np.log(arr)) + def testLog1p(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log1p(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log1p(arr)) + def testNeg(self): c = self._NewComputation() arr = NumpyArrayF32([3.3, 12.1]) diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 2698ba7d79e246530b6b486d3e3bc8bf101c891e..8fa6961d197dce519cf151283b8bc0836a4615c0 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -265,9 +265,9 @@ class ReferenceUtil { const Array3D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 3); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()}; for (int i = 0; i < 3; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -299,9 +299,9 @@ class ReferenceUtil { const Array4D& rhs, int concatenate_dimension) { CHECK(0 <= concatenate_dimension && concatenate_dimension < 4); - std::vector lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; - std::vector rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; - std::vector out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}; + const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; + int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}; for (int i = 0; i < 4; ++i) { if (i != concatenate_dimension) { out_dims[i] = lhs_dims[i]; @@ -553,12 +553,11 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 3); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()}; + int64 pad_low[3]; + int64 pad_high[3]; + int64 pad_interior[3]; + int64 output_bounds[3]; for (int64 i = 0; i < 3; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); @@ -574,7 +573,7 @@ class ReferenceUtil { Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; + int indices[] = {0, 0, 0}; for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { @@ -612,12 +611,12 @@ class ReferenceUtil { const NativeT pad) { CHECK_EQ(padding.dimensions_size(), 4); - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); + const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(), + operand.n4()}; + int64 pad_low[4]; + int64 pad_high[4]; + int64 pad_interior[4]; + int64 output_bounds[4]; for (int64 i = 0; i < 4; ++i) { pad_low[i] = padding.dimensions(i).edge_padding_low(); pad_high[i] = padding.dimensions(i).edge_padding_high(); diff --git a/tensorflow/compiler/xla/rpc/BUILD b/tensorflow/compiler/xla/rpc/BUILD index 0d56a9a477b15964ad45e798865aa8d2c7385073..1775666652f303ee095a11405537106a3eb9b056 100644 --- a/tensorflow/compiler/xla/rpc/BUILD +++ b/tensorflow/compiler/xla/rpc/BUILD @@ -42,7 +42,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) @@ -61,7 +61,7 @@ tf_cc_test( "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) @@ -74,6 +74,6 @@ cc_library( "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) diff --git a/tensorflow/compiler/xla/rpc/grpc_client_test.cc b/tensorflow/compiler/xla/rpc/grpc_client_test.cc index 313f11a9a957155eb277dc02ba5d2565c87e0235..d7dd9786a2bbde2d18ae81a9a9d4cc4b2cc38411 100644 --- a/tensorflow/compiler/xla/rpc/grpc_client_test.cc +++ b/tensorflow/compiler/xla/rpc/grpc_client_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "grpc++/create_channel.h" -#include "grpc++/security/credentials.h" +#include "grpcpp/create_channel.h" +#include "grpcpp/security/credentials.h" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_service.cc b/tensorflow/compiler/xla/rpc/grpc_service.cc index 5f4dc6bd08f18b50e60b173432d3d305759bccea..4e1435fa30a24c320ddbedb84d37b369a3158a54 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service.cc @@ -32,19 +32,6 @@ namespace xla { return tensorflow::ToGrpcStatus(s); } -::grpc::Status GRPCService::Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Computation(arg, result); }); -} - -::grpc::Status GRPCService::CreateOp(::grpc::ServerContext* context, - const OpRequest* arg, OpResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Op(arg, result); }); -} - ::grpc::Status GRPCService::Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) { @@ -60,21 +47,6 @@ namespace xla { }); } -::grpc::Status GRPCService::SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - return DelegateRPC([this, arg, results]() { - return service_->SetReturnValue(arg, results); - }); -} - -::grpc::Status GRPCService::Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->Execute(arg, result); }); -} - ::grpc::Status GRPCService::ExecuteGraph(::grpc::ServerContext* /*context*/, const ExecuteGraphRequest* arg, ExecuteResponse* result) { @@ -82,13 +54,6 @@ namespace xla { [this, arg, result]() { return service_->ExecuteGraph(arg, result); }); } -::grpc::Status GRPCService::ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ExecuteAsync(arg, result); }); -} - ::grpc::Status GRPCService::WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { @@ -136,20 +101,6 @@ namespace xla { [this, arg, result]() { return service_->ResetDevice(arg, result); }); } -::grpc::Status GRPCService::IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->IsConstant(arg, result); }); -} - -::grpc::Status GRPCService::ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->ComputeConstant(arg, result); }); -} - ::grpc::Status GRPCService::GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) { @@ -157,43 +108,4 @@ namespace xla { [this, arg, result]() { return service_->GetShape(arg, result); }); } -::grpc::Status GRPCService::GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationShape(arg, result); - }); -} - -::grpc::Status GRPCService::GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - return DelegateRPC( - [this, arg, result]() { return service_->GetLocalShape(arg, result); }); -} - -::grpc::Status GRPCService::GetComputationStats( - ::grpc::ServerContext* context, const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->GetComputationStats(arg, result); - }); -} - -::grpc::Status GRPCService::SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->SnapshotComputation(arg, result); - }); -} - -::grpc::Status GRPCService::LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - return DelegateRPC([this, arg, result]() { - return service_->LoadComputationSnapshot(arg, result); - }); -} - } // namespace xla diff --git a/tensorflow/compiler/xla/rpc/grpc_service.h b/tensorflow/compiler/xla/rpc/grpc_service.h index 50f02796f2d45baf894841782cd96d8d51a5ba00..ca1b09b648013ad45d806040c5ddcf11d9e5604e 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service.h +++ b/tensorflow/compiler/xla/rpc/grpc_service.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ #define TENSORFLOW_COMPILER_XLA_RPC_GRPC_SERVICE_H_ -#include "grpc++/server_context.h" +#include "grpcpp/server_context.h" #include "tensorflow/compiler/xla/rpc/xla_service.grpc.pb.h" #include "tensorflow/compiler/xla/service/service.h" @@ -31,13 +31,6 @@ class GRPCService : public grpc::XlaService::Service { static StatusOr> NewService( se::Platform* platform = nullptr); - ::grpc::Status Computation(::grpc::ServerContext* context, - const ComputationRequest* arg, - ComputationResponse* result) override; - - ::grpc::Status CreateOp(::grpc::ServerContext* context, const OpRequest* arg, - OpResponse* result) override; - ::grpc::Status Unregister(::grpc::ServerContext* context, const UnregisterRequest* arg, UnregisterResponse* result) override; @@ -46,22 +39,10 @@ class GRPCService : public grpc::XlaService::Service { const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - ::grpc::Status SetReturnValue(::grpc::ServerContext* context, - const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - ::grpc::Status Execute(::grpc::ServerContext* context, - const ExecuteRequest* arg, - ExecuteResponse* result) override; - ::grpc::Status ExecuteGraph(::grpc::ServerContext* context, const ExecuteGraphRequest* arg, ExecuteResponse* result) override; - ::grpc::Status ExecuteAsync(::grpc::ServerContext* context, - const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - ::grpc::Status WaitForExecution(::grpc::ServerContext* context, const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; @@ -86,38 +67,10 @@ class GRPCService : public grpc::XlaService::Service { const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - ::grpc::Status IsConstant(::grpc::ServerContext* context, - const IsConstantRequest* arg, - IsConstantResponse* result) override; - - ::grpc::Status ComputeConstant(::grpc::ServerContext* context, - const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - ::grpc::Status GetShape(::grpc::ServerContext* context, const GetShapeRequest* arg, GetShapeResponse* result) override; - ::grpc::Status GetComputationShape( - ::grpc::ServerContext* context, const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ::grpc::Status GetLocalShape(::grpc::ServerContext* context, - const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - ::grpc::Status GetComputationStats(::grpc::ServerContext* context, - const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - ::grpc::Status SnapshotComputation( - ::grpc::ServerContext* context, const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - ::grpc::Status LoadComputationSnapshot( - ::grpc::ServerContext* context, const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - private: std::unique_ptr<::xla::Service> service_; diff --git a/tensorflow/compiler/xla/rpc/grpc_service_main.cc b/tensorflow/compiler/xla/rpc/grpc_service_main.cc index e29908ccec80db76e3b5b856e57382c56430c379..c68c857c304138ff4318e243f66547c6acce1005 100644 --- a/tensorflow/compiler/xla/rpc/grpc_service_main.cc +++ b/tensorflow/compiler/xla/rpc/grpc_service_main.cc @@ -15,9 +15,9 @@ limitations under the License. // Basic server binary that exposes a xla::Service through a GRPC interface // on a configurable port. -#include "grpc++/security/server_credentials.h" -#include "grpc++/server.h" -#include "grpc++/server_builder.h" +#include "grpcpp/security/server_credentials.h" +#include "grpcpp/server.h" +#include "grpcpp/server_builder.h" #include "tensorflow/compiler/xla/rpc/grpc_service.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/init_main.h" diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.cc b/tensorflow/compiler/xla/rpc/grpc_stub.cc index 620ac6cec4f76d938e57e87849066df59514938a..7b8ab158e1396d7087a407be180ab44d2e16e121 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.cc +++ b/tensorflow/compiler/xla/rpc/grpc_stub.cc @@ -62,21 +62,6 @@ Status GRPCStub::ResetDevice(const ResetDeviceRequest* request, }); } -Status GRPCStub::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->LoadComputationSnapshot(context, *request, response); - }); -} - -Status GRPCStub::Execute(const ExecuteRequest* request, - ExecuteResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Execute(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -84,13 +69,6 @@ Status GRPCStub::ExecuteGraph(const ExecuteGraphRequest* request, }); } -Status GRPCStub::ExecuteParallel(const ExecuteParallelRequest* request, - ExecuteParallelResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteParallel(context, *request, response); - }); -} - Status GRPCStub::ExecuteGraphParallel( const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) { @@ -99,13 +77,6 @@ Status GRPCStub::ExecuteGraphParallel( }); } -Status GRPCStub::ExecuteAsync(const ExecuteAsyncRequest* request, - ExecuteAsyncResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ExecuteAsync(context, *request, response); - }); -} - Status GRPCStub::WaitForExecution(const WaitForExecutionRequest* request, WaitForExecutionResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -120,13 +91,6 @@ Status GRPCStub::DeconstructTuple(const DeconstructTupleRequest* request, }); } -Status GRPCStub::GetComputationStats(const ComputationStatsRequest* request, - ComputationStatsResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationStats(context, *request, response); - }); -} - Status GRPCStub::GetComputationGraphStats( const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) { @@ -135,13 +99,6 @@ Status GRPCStub::GetComputationGraphStats( }); } -Status GRPCStub::GetComputationShape(const GetComputationShapeRequest* request, - GetComputationShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetComputationShape(context, *request, response); - }); -} - Status GRPCStub::GetShape(const GetShapeRequest* request, GetShapeResponse* response) { return MakeRPC([this, request, response](::grpc::ClientContext* context) { @@ -163,48 +120,6 @@ Status GRPCStub::CreateChannelHandle(const CreateChannelHandleRequest* request, }); } -// Methods used by ComputationBuilder. -Status GRPCStub::Computation(const ComputationRequest* request, - ComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->Computation(context, *request, response); - }); -} - -Status GRPCStub::Op(const OpRequest* request, OpResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->CreateOp(context, *request, response); - }); -} - -Status GRPCStub::GetLocalShape(const GetLocalShapeRequest* request, - GetLocalShapeResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->GetLocalShape(context, *request, response); - }); -} - -Status GRPCStub::SetReturnValue(const SetReturnValueRequest* request, - SetReturnValueResponse* responses) { - return MakeRPC([this, request, responses](::grpc::ClientContext* context) { - return grpc_stub_->SetReturnValue(context, *request, responses); - }); -} - -Status GRPCStub::IsConstant(const IsConstantRequest* request, - IsConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->IsConstant(context, *request, response); - }); -} - -Status GRPCStub::ComputeConstant(const ComputeConstantRequest* request, - ComputeConstantResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->ComputeConstant(context, *request, response); - }); -} - Status GRPCStub::ComputeConstantGraph( const ComputeConstantGraphRequest* request, ComputeConstantResponse* response) { @@ -213,14 +128,6 @@ Status GRPCStub::ComputeConstantGraph( }); } -// Methods used by Computation. -Status GRPCStub::SnapshotComputation(const SnapshotComputationRequest* request, - SnapshotComputationResponse* response) { - return MakeRPC([this, request, response](::grpc::ClientContext* context) { - return grpc_stub_->SnapshotComputation(context, *request, response); - }); -} - // Methods used by GlobalData. Status GRPCStub::Unregister(const UnregisterRequest* request, UnregisterResponse* response) { diff --git a/tensorflow/compiler/xla/rpc/grpc_stub.h b/tensorflow/compiler/xla/rpc/grpc_stub.h index 5906d45769b5749b0c590dbc0e1972077dc3e7ba..8dfcb761387d608abbb1f62974f49b976a7ff7ff 100644 --- a/tensorflow/compiler/xla/rpc/grpc_stub.h +++ b/tensorflow/compiler/xla/rpc/grpc_stub.h @@ -43,39 +43,21 @@ class GRPCStub : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) override; - - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - Status ExecuteGraph(const ExecuteGraphRequest* request, ExecuteResponse* response) override; - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* request, ExecuteParallelResponse* response) override; - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override; Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - Status GetComputationGraphStats(const ComputationGraphStatsRequest* request, ComputationStatsResponse* response) override; - Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) override; @@ -85,30 +67,9 @@ class GRPCStub : public ServiceInterface { Status CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) override; - // Methods used by ComputationBuilder. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - - Status Op(const OpRequest* arg, OpResponse* result) override; - Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; - Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) override; - // Methods used by Computation. - Status SnapshotComputation(const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) override; - // Methods used by GlobalData. Status Unregister(const UnregisterRequest* arg, UnregisterResponse* result) override; diff --git a/tensorflow/compiler/xla/rpc/xla_service.proto b/tensorflow/compiler/xla/rpc/xla_service.proto index c47164ee1b7657ae378a053f553442bee751753e..551ae895e05586daec0ffcd425f4950f76bdd50d 100644 --- a/tensorflow/compiler/xla/rpc/xla_service.proto +++ b/tensorflow/compiler/xla/rpc/xla_service.proto @@ -75,19 +75,7 @@ service XlaService { rpc GetShape(GetShapeRequest) returns (GetShapeResponse) { } - // Requests the program shape of the referenced computation. - rpc GetComputationShape(GetComputationShapeRequest) - returns (GetComputationShapeResponse) { - } - - // Requests the statistics of the given computation. - rpc GetComputationStats(ComputationStatsRequest) - returns (ComputationStatsResponse) { - } - // Requests the statistics of the given computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc GetComputationGraphStats(ComputationGraphStatsRequest) returns (ComputationStatsResponse) { } @@ -121,25 +109,12 @@ service XlaService { rpc ResetDevice(ResetDeviceRequest) returns (ResetDeviceResponse) { } - // Tests if an expression is a compile-time constant. - rpc IsConstant(IsConstantRequest) returns (IsConstantResponse) { - } - - // Computes the value of a constant expression. - rpc ComputeConstant(ComputeConstantRequest) - returns (ComputeConstantResponse) { - } - // Computes the value of a constant expression. The request contains the // computation graph for the constant expression. rpc ComputeConstantGraph(ComputeConstantGraphRequest) returns (ComputeConstantResponse) { } - // Retrieves the inferred shape for a value within a computation. - rpc GetLocalShape(GetLocalShapeRequest) returns (GetLocalShapeResponse) { - } - // Requests one or more device handles from the target. The returned device // handles can be used to specify the device on which to execute computations // or transfer data. @@ -153,32 +128,6 @@ service XlaService { returns (CreateChannelHandleResponse) { } - // Requests that the referenced computation be specialized for the provided - // arguments for subsequent execution. This permits things such as value - // specialization. - rpc Specialize(SpecializeRequest) returns (SpecializeResponse) { - } - - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - rpc SetReturnValue(SetReturnValueRequest) returns (SetReturnValueResponse) { - } - - // Computation creates a new computation with the given name. - // A unique ComputationHandle is returned. - rpc Computation(ComputationRequest) returns (ComputationResponse) { - } - - // Adds a new op to a computation. - rpc CreateOp(OpRequest) returns (OpResponse) { - } - - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - rpc Execute(ExecuteRequest) returns (ExecuteResponse) { - } - // Invokes the provided computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. @@ -188,38 +137,13 @@ service XlaService { // Invokes the provided list of computations in parallel with the provided // global data for each computation. Returns a list of global data output and // execution timing. - rpc ExecuteParallel(ExecuteParallelRequest) - returns (ExecuteParallelResponse) { - } - - // Invokes the provided list of computations in parallel with the provided - // global data for each computation. Returns a list of global data output and - // execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. rpc ExecuteGraphParallel(ExecuteGraphParallelRequest) returns (ExecuteParallelResponse) { } - // Invokes the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - rpc ExecuteAsync(ExecuteAsyncRequest) returns (ExecuteAsyncResponse) { - } - // Waits until the given execution (aysnchronously launched) is complete, and // returns the global data output. rpc WaitForExecution(WaitForExecutionRequest) returns (WaitForExecutionResponse) { } - - // Serializes a computation to proto form, so it can be loaded via - // LoadComputationSnapshot. - rpc SnapshotComputation(SnapshotComputationRequest) - returns (SnapshotComputationResponse) { - } - - // Loads a computation from a captured snapshot. - rpc LoadComputationSnapshot(LoadComputationSnapshotRequest) - returns (LoadComputationSnapshotResponse) { - } } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d1722644c72646538dab77899b79d25056f2f2bf..2942edbf71f29304ebb240f0547808ae0af1ac87 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -16,19 +16,22 @@ load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) xla_proto_library( - name = "session_proto", - srcs = ["session.proto"], + name = "hlo_proto", + srcs = ["hlo.proto"], visibility = ["//visibility:public"], deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) -xla_proto_library( - name = "hlo_proto", +tf_proto_library_py( + name = "hlo_proto", # bzl adds a _py suffix only to the OSS target. srcs = ["hlo.proto"], visibility = ["//visibility:public"], - deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) xla_proto_library( @@ -266,6 +269,7 @@ cc_library( "dfs_hlo_visitor.cc", "hlo_computation.cc", "hlo_instruction.cc", + "hlo_instructions.cc", "hlo_module.cc", "hlo_opcode.cc", "hlo_sharding.cc", @@ -273,18 +277,21 @@ cc_library( hdrs = [ "dfs_hlo_visitor.h", "dfs_hlo_visitor_with_default.h", + "hlo_clone_context.h", "hlo_computation.h", + "hlo_domain_metadata.h", "hlo_instruction.h", + "hlo_instructions.h", "hlo_module.h", "hlo_opcode.h", "hlo_sharding.h", ], deps = [ + ":hlo_casting_utils", ":hlo_module_config", ":hlo_proto", ":hlo_reachability", ":name_uniquer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -297,6 +304,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:human_readable_json", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], @@ -336,8 +344,8 @@ tf_cc_test( ":hlo", ":pattern_matcher", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -375,6 +383,7 @@ cc_library( deps = [ ":hlo", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -385,20 +394,9 @@ tf_cc_test( srcs = ["hlo_matchers_test.cc"], deps = [ ":hlo_matchers", + ":hlo_parser", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", - ], -) - -cc_library( - name = "versioned_computation_handle", - srcs = ["versioned_computation_handle.cc"], - hdrs = ["versioned_computation_handle.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", ], ) @@ -407,12 +405,14 @@ tf_cc_test( srcs = ["hlo_instruction_test.cc"], deps = [ ":hlo", + ":hlo_parser", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -429,6 +429,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", ], @@ -533,45 +534,6 @@ tf_cc_test( ], ) -cc_library( - name = "user_computation", - srcs = ["user_computation.cc"], - hdrs = ["user_computation.h"], - deps = [ - ":hlo", - ":session_proto", - ":shape_inference", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "user_computation_test", - srcs = ["user_computation_test.cc"], - deps = [ - ":hlo_matchers", - ":user_computation", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "platform_util", srcs = ["platform_util.cc"], @@ -617,10 +579,8 @@ cc_library( ":allocation_tracker", ":backend", ":channel_tracker", - ":compilation_cache", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":execution_tracker", @@ -631,11 +591,8 @@ cc_library( ":hlo_module_config", ":hlo_proto_util", ":platform_util", - ":session_proto", ":source_map_util", ":transfer_manager", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:service_interface", @@ -662,7 +619,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":device_memory_allocator", ":executable", ":hlo", @@ -671,8 +627,6 @@ cc_library( ":platform_util", ":service", ":shaped_buffer", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -696,7 +650,6 @@ cc_library( ":backend", ":compiler", ":computation_layout", - ":computation_tracker", ":platform_util", ":service", "//tensorflow/compiler/xla:status_macros", @@ -793,9 +746,7 @@ cc_library( ":hlo_graph_dumper", ":hlo_proto", ":pool", - ":session_proto", ":shaped_buffer", - ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -891,34 +842,12 @@ cc_library( ], ) -cc_library( - name = "computation_tracker", - srcs = ["computation_tracker.cc"], - hdrs = ["computation_tracker.h"], - deps = [ - ":hlo", - ":hlo_module_config", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "channel_tracker", srcs = ["channel_tracker.cc"], hdrs = ["channel_tracker.h"], deps = [ ":hlo", - ":session_proto", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1024,7 +953,6 @@ tf_cc_test( ":buffer_assignment", ":buffer_value", ":call_graph", - ":computation_tracker", ":copy_insertion", ":cpu_plugin", ":flatten_call_graph", @@ -1038,9 +966,9 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1076,9 +1004,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1179,9 +1107,9 @@ tf_cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1214,9 +1142,22 @@ tf_cc_test( deps = [ ":hlo_matchers", ":instruction_fusion", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", ], ) @@ -1388,9 +1329,9 @@ tf_cc_test( deps = [ ":gather_expander", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:test_macros_header", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1696,14 +1637,11 @@ tf_cc_test( name = "hlo_cost_analysis_test", srcs = ["hlo_cost_analysis_test.cc"], deps = [ - ":computation_tracker", ":cpu_plugin", ":hlo", ":hlo_cost_analysis", ":local_service", ":service", - ":user_computation", - ":versioned_computation_handle", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", @@ -1742,9 +1680,9 @@ tf_cc_test( ":cpu_plugin", ":hlo_cost_analysis", ":hlo_execution_profile", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -1925,9 +1863,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2044,20 +1982,6 @@ tf_cc_test( ], ) -cc_library( - name = "compilation_cache", - srcs = ["compilation_cache.cc"], - hdrs = ["compilation_cache.h"], - deps = [ - ":executable", - ":hlo_module_config", - ":versioned_computation_handle", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - ], -) - cc_library( name = "layout_assignment", srcs = [ @@ -2228,6 +2152,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -2262,11 +2187,11 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2288,9 +2213,9 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -2338,6 +2263,7 @@ cc_library( hdrs = ["hlo_cse.h"], deps = [ ":hlo", + ":hlo_domain_map", ":hlo_pass", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -2360,10 +2286,10 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2402,6 +2328,78 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_domain_map", + srcs = ["hlo_domain_map.cc"], + hdrs = ["hlo_domain_map.h"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_sharding_metadata", + srcs = ["hlo_sharding_metadata.cc"], + hdrs = [ + "hlo_sharding_metadata.h", + ], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_domain_isolator", + srcs = ["hlo_domain_isolator.cc"], + hdrs = ["hlo_domain_isolator.h"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + ], +) + +cc_library( + name = "hlo_domain_remover", + srcs = ["hlo_domain_remover.cc"], + hdrs = ["hlo_domain_remover.h"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_map", + ":hlo_graph_dumper", + ":hlo_pass", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_domain_test", + srcs = ["hlo_domain_test.cc"], + deps = [ + ":hlo", + ":hlo_domain_isolator", + ":hlo_domain_remover", + ":hlo_parser", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_element_type_converter", srcs = ["hlo_element_type_converter.cc"], @@ -2483,10 +2481,10 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2544,7 +2542,6 @@ cc_library( name = "hlo_tfgraph_builder", srcs = ["hlo_tfgraph_builder.cc"], hdrs = ["hlo_tfgraph_builder.h"], - visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], deps = [ ":hlo", "//tensorflow/compiler/xla:literal_util", @@ -2575,6 +2572,7 @@ cc_library( hdrs = ["hlo_graph_dumper.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_execution_profile", ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", @@ -2632,10 +2630,10 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -2772,7 +2770,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -2808,8 +2806,8 @@ tf_cc_test( ":tuple_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2832,9 +2830,10 @@ tf_cc_test( deps = [ ":while_util", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -2860,6 +2859,7 @@ tf_cc_test( ":hlo_matchers", ":while_loop_invariant_code_motion", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/core:test", ], @@ -2886,8 +2886,8 @@ tf_cc_test( ":hlo_matchers", ":while_loop_constant_sinking", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:test", ], ) @@ -2925,6 +2925,7 @@ cc_library( hdrs = ["indexed_array_analysis.h"], deps = [ ":hlo", + ":hlo_evaluator", ":hlo_pass", "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", @@ -2939,9 +2940,75 @@ tf_cc_test( ":hlo_matchers", ":indexed_array_analysis", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo", + ":hlo_lexer", + ":hlo_sharding_metadata", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/compiler/xla:window_util", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_casting_utils", + hdrs = ["hlo_casting_utils.h"], + deps = ["//tensorflow/core:lib"], +) + +tf_cc_test( + name = "hlo_casting_utils_test", + srcs = ["hlo_casting_utils_test.cc"], + deps = [ + ":hlo", + ":hlo_casting_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index f732ed8f398c4699bd5247dc7fa1e9677340dcae..3b36939b8a6900f047bbec225aa232e0e805b5d1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -157,6 +157,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub) override; + Status HandleMap(HloInstruction* map) override; + Status HandleMaximum(HloInstruction* maximum) override; Status HandleMinimum(HloInstruction* minimum) override; @@ -231,10 +233,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand); - // A Reshape or Broadcast that feeds an element-wise operation with a unique - // non-scalar operand can sink to after the operation. - StatusOr TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast); + // A Broadcast that feeds an element-wise operation with a unique non-scalar + // operand can sink to after the operation. + StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast); // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -1303,7 +1305,7 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); + TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); changed_ |= sink_succeeded; if (sink_succeeded) { return Status::OK(); @@ -1555,15 +1557,16 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { return Status::OK(); } -StatusOr AlgebraicSimplifierVisitor:: - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* reshape_or_broadcast) { +StatusOr +AlgebraicSimplifierVisitor::TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast) { + TF_RET_CHECK(broadcast->opcode() == HloOpcode::kBroadcast); bool changed = false; - if (ShapeUtil::IsScalar(reshape_or_broadcast->shape())) { + if (ShapeUtil::IsScalar(broadcast->shape())) { return false; } - HloInstruction* operand = reshape_or_broadcast->mutable_operand(0); - for (HloInstruction* user : reshape_or_broadcast->users()) { + HloInstruction* operand = broadcast->mutable_operand(0); + for (HloInstruction* user : broadcast->users()) { if (user->user_count() == 0 && user != computation_->root_instruction()) { continue; } @@ -1581,55 +1584,50 @@ StatusOr AlgebraicSimplifierVisitor:: continue; } - int64 reshape_or_broadcast_operand_index = -1; // Find the unique non-scalar operand or continue if there isn't one. - int64 scalar_count = 0; - for (int64 i = 0; i < user->operand_count(); ++i) { - if (ShapeUtil::IsScalar(user->operand(i)->shape())) { - ++scalar_count; - } else { - reshape_or_broadcast_operand_index = i; + int64 scalar_broadcast_count = 0; + int64 broadcast_use_count = 0; + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + ++scalar_broadcast_count; + } else if (broadcast == user_operand) { + ++broadcast_use_count; } } - if (scalar_count != user->operand_count() - 1) { + if (scalar_broadcast_count + broadcast_use_count != user->operand_count()) { continue; } - VLOG(4) << "Sinking reshape or broadcast after user:"; - VLOG(4) << " old reshape/broadcast: " << reshape_or_broadcast->ToString(); + std::vector new_operands; + new_operands.reserve(user->operand_count()); + + for (HloInstruction* user_operand : user->operands()) { + if (user_operand->opcode() == HloOpcode::kBroadcast && + ShapeUtil::IsScalar(user_operand->operand(0)->shape())) { + new_operands.push_back( + computation_->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::ChangeElementType( + operand->shape(), user_operand->shape().element_type()), + user_operand->mutable_operand(0), {}))); + } else { + CHECK_EQ(broadcast, user_operand); + new_operands.push_back(operand); + } + } + VLOG(4) << "Sinking broadcast after user:"; + VLOG(4) << " old broadcast: " << broadcast->ToString(); VLOG(4) << " old user: " << user->ToString(); - CHECK_EQ(user->operand(reshape_or_broadcast_operand_index), - reshape_or_broadcast); - auto new_user_operands = user->operands(); - new_user_operands[reshape_or_broadcast_operand_index] = operand; - auto new_user = computation_->AddInstruction(user->CloneWithNewOperands( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(operand->shape().dimensions()), - LayoutUtil::MinorToMajor(operand->shape())), - new_user_operands)); + HloInstruction* new_user = + computation_->AddInstruction(user->CloneWithNewOperands( + ShapeUtil::ChangeElementType(operand->shape(), + user->shape().element_type()), + new_operands)); VLOG(4) << " new user: " << new_user->ToString(); - HloInstruction* new_reshape_or_broadcast = nullptr; - if (reshape_or_broadcast->opcode() == HloOpcode::kReshape) { - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user)); - } else { - TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); - new_reshape_or_broadcast = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShapeWithLayout( - user->shape().element_type(), - AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), - new_user, reshape_or_broadcast->dimensions())); - } - VLOG(4) << " new reshape/broadcast: " - << new_reshape_or_broadcast->ToString(); - TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); + HloInstruction* new_broadcast = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + user->shape(), new_user, broadcast->dimensions())); + VLOG(4) << " new broadcast: " << new_broadcast->ToString(); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_broadcast)); changed = true; } return changed; @@ -1672,16 +1670,6 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { } } - // A Reshape that feeds a unary element-wise operation can sink the - // reshape after the unary element-wise operation. - TF_ASSIGN_OR_RETURN( - bool sink_succeeded, - TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - changed_ |= sink_succeeded; - if (sink_succeeded) { - return Status::OK(); - } - // Make this a bitcast if possible. if (is_layout_sensitive_ && ReshapeIsBitcast(reshape, valid_bitcast_callback_)) { @@ -1786,6 +1774,46 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } + // If the reduction results in the same number of elements, then the only + // possible side effect would be a reshape. Since the init_value is an + // identity of the reduction function, we can therefore replace the reduce + // with a simple reshape, ignoring the reduction function completely. + if (ShapeUtil::ElementsIn(reduce->shape()) == + ShapeUtil::ElementsIn(arg->shape())) { + return ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReshape(reduce->shape(), arg)); + } + + // If a reduce feeds a reduce with the same computation and initial value, + // they can be combined into a single reduce. + if (arg->opcode() == HloOpcode::kReduce && + init_value->Identical(*arg->operand(1)) && + *function == *arg->to_apply()) { + // Create a new reduce with the combined reduction dimensions of both + // reduces. + std::vector arg_dims = arg->dimensions(); + std::sort(arg_dims.begin(), arg_dims.end()); + std::vector reduce_dims = reduce->dimensions(); + std::sort(reduce_dims.begin(), reduce_dims.end()); + // Transform reduce_dims to the same rank as the operand of the operand. + for (int64 arg_dim : arg_dims) { + for (int64& dim : reduce_dims) { + if (dim >= arg_dim) { + ++dim; + } + } + } + std::vector new_dimensions; + new_dimensions.reserve(arg->dimensions().size() + + reduce->dimensions().size()); + std::merge(arg_dims.begin(), arg_dims.end(), reduce_dims.begin(), + reduce_dims.end(), std::back_inserter(new_dimensions)); + return ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateReduce(reduce->shape(), arg->mutable_operand(0), + init_value, new_dimensions, function)); + } + // A reshape that collapses multiple dimensions into a dimension being // reduced can just reduce all of those dimensions instead of doing a // collapsing reshape before a reduction. @@ -1830,15 +1858,6 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { new_reduce_dimensions, function)); } } - if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape()) || - ShapeUtil::HasZeroElements(arg->shape())) { - auto reshape = computation_->AddInstruction( - HloInstruction::CreateReshape(reduce->shape(), arg)); - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateMap(reduce->shape(), - {init_value, reshape}, function)); - } return Status::OK(); } @@ -1858,7 +1877,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateMap(reduce_window->shape(), - {operand, reduce_window->mutable_operand(1)}, + {reduce_window->mutable_operand(1), operand}, function)); } @@ -2188,6 +2207,39 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } +Status AlgebraicSimplifierVisitor::HandleMap(HloInstruction* map) { + auto* map_computation = map->to_apply(); + auto* map_root = map_computation->root_instruction(); + if (map_root->opcode() == HloOpcode::kParameter) { + ReplaceInstructionIfSameShape( + map, map->mutable_operand(map_root->parameter_number())); + return Status::OK(); + } + if (map_root->opcode() == HloOpcode::kConstant) { + if (!ShapeUtil::IsScalar(map_root->shape())) { + return Status::OK(); + } + auto clone = map_root->CloneWithNewOperands(map_root->shape(), {}); + if (ShapeUtil::IsScalar(map->shape())) { + return ReplaceWithNewInstruction(map, std::move(clone)); + } + return ReplaceWithNewInstruction( + map, + HloInstruction::CreateBroadcast( + map->shape(), computation_->AddInstruction(std::move(clone)), {})); + } + std::vector new_operands; + for (auto* root_operand : map_root->operands()) { + if (root_operand->opcode() != HloOpcode::kParameter) { + return Status::OK(); + } + new_operands.push_back( + map->mutable_operand(root_operand->parameter_number())); + } + auto clone = map_root->CloneWithNewOperands(map->shape(), new_operands); + return ReplaceWithNewInstruction(map, std::move(clone)); +} + Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand operand diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 4e082877c776c35bab499c805fef7632765a3ee1..2605b0488cb7c6850746df94c4ab05d6b5d35de5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -74,6 +74,44 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Reduce(Reduce(A)) -> Reduce(A) +TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r4f32, "param")); + std::vector dims0({0}); + Shape r3f32 = ShapeUtil::MakeShape(F32, {5, 6, 7}); + HloInstruction* reduce0 = builder.AddInstruction( + HloInstruction::CreateReduce(r3f32, param, zero, dims0, add_computation)); + std::vector dims1({1, 2}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); + builder.AddInstruction(HloInstruction::CreateReduce(r1f32, reduce0, zero, + dims1, add_computation)); + module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + HloInstruction* root = module().entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Reduce(param, zero)); + EXPECT_EQ(root->dimensions(), std::vector({0, 2, 3})); +} + // Test that Const + A is canonicalized to A + Const. TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -143,6 +181,39 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, InlineTrivialMap) { + HloComputation::Builder builder(TestName()); + // Create add computation. + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + Shape r2f32 = ShapeUtil::MakeShape(F32, {32, 1}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {param0, zero}, add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kMap); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, zero)); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -1318,32 +1389,6 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { op::Tuple(op::Bitcast(), dimensions_wrong_reshape, layout_wrong_reshape)); } -TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { - HloComputation::Builder builder(TestName()); - HloInstruction* param = - builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), "param")); - HloInstruction* movable_reshape = - builder.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); - HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), - HloOpcode::kMaximum, movable_reshape, zero)); - auto computation = module().AddEntryComputation(builder.Build()); - - EXPECT_THAT(computation->root_instruction(), - op::Maximum(op::Reshape(param), zero)); - - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - bitcasting_callback()); - - simplifier.Run(&module()).ValueOrDie(); - EXPECT_THAT(computation->root_instruction(), - op::Reshape(op::Maximum(param, zero))); -} - // Regression test for a bug in the reshape sinking transformation, where // moving a reshape to a scalar led to a crash. TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { @@ -1707,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1752,7 +1797,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1774,7 +1819,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1797,7 +1842,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1925,7 +1970,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, @@ -2053,7 +2099,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2083,7 +2129,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2114,7 +2160,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2144,7 +2190,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2177,7 +2223,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2193,10 +2239,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction::CreateParameter(0, r0f32, "scalar_param")); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, scalar_param, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {})); Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3}); HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( @@ -2212,10 +2256,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2230,10 +2274,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); - HloInstruction* broadcast = - builder.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, forty_two, - AsInt64Slice(broadcast_shape.dimensions()))); + HloInstruction* broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {})); HloInstruction* transpose = builder.AddInstruction(HloInstruction::CreateTranspose( @@ -2252,7 +2294,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2261,7 +2303,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2342,7 +2385,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - auto module = CreateNewModule(); + // TODO(b/80488902): verify this module. + auto module = HloTestBase::CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2437,7 +2481,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 96e02b82b97ff2fd682638f4c6297cbc2019c481..ec13fadbc75e2315d1d6ef72e24a0faca0c7de40 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -58,8 +58,7 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { // Runs the visitor on a computation. static bool Run(HloComputation* computation, bool rewrite_training_op, - bool rewrite_inference_op, bool rewrite_grad_op, - bool use_fusion); + bool rewrite_inference_op, bool rewrite_grad_op); // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } @@ -70,21 +69,14 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) + bool rewrite_grad_op) : computation_(computation), rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} HloComputation* GetOrCreateScalarAddComputation( PrimitiveType primitive_type) { - HloComputation** scalar_add_computation = - &scalar_add_computations_[primitive_type]; - if (*scalar_add_computation) { - return *scalar_add_computation; - } - HloComputation::Builder b("scalar_add_computation"); Shape shape = ShapeUtil::MakeShape(primitive_type, {}); auto scalar_lhs = b.AddInstruction( @@ -93,26 +85,39 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - *scalar_add_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return *scalar_add_computation; + return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - // Current HloComputation instance the BatchNormExpander is - // traversing. - HloComputation* computation_; - - bool rewrite_training_op_; - bool rewrite_inference_op_; - bool rewrite_grad_op_; - bool use_fusion_; - - // Whether rewrite has occurred. - bool changed_ = false; + std::unique_ptr Rsqrt( + HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* exponent = add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(-0.5f))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kPower, + operand, exponent); + } - // Cached computations for adding two scalars. - tensorflow::gtl::FlatMap - scalar_add_computations_; + std::unique_ptr Mean( + int64 element_count, HloInstruction* operand, + const std::function)>& + add_instruction) { + HloInstruction* elem_count_recip = + add_instruction(HloInstruction::CreateBroadcast( + operand->shape(), + add_instruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(operand->shape().element_type(), {}), + add_instruction(HloInstruction::CreateConstant( + Literal::CreateR0(1.0 / element_count))))), + {})); + return HloInstruction::CreateBinary(operand->shape(), HloOpcode::kMultiply, + operand, elem_count_recip); + } // Replaces the existing HLO instruction old_instruction, with // new_instruction, and marks the optimizer status as changed. @@ -136,6 +141,16 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { changed_ = true; return Status::OK(); } + // Current HloComputation instance the BatchNormExpander is + // traversing. + HloComputation* computation_; + + bool rewrite_training_op_; + bool rewrite_inference_op_; + bool rewrite_grad_op_; + + // Whether rewrite has occurred. + bool changed_ = false; }; } // namespace @@ -143,13 +158,12 @@ class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, - bool rewrite_grad_op, bool use_fusion) { + bool rewrite_grad_op) { BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, - /*rewrite_grad_op=*/rewrite_grad_op, - /*use_fusion=*/use_fusion); + /*rewrite_grad_op=*/rewrite_grad_op); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -167,6 +181,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); // Expand batch norm training into smaller HLO ops. @@ -176,12 +194,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + int64 elements_per_feature_int64 = size_in_elements / feature_count; HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -193,8 +206,9 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = - add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = add(HloInstruction::CreateBroadcast( + operand_shape, + add(HloInstruction::CreateConstant(std::move(epsilon_literal))), {})); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -213,8 +227,8 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( GetOrCreateScalarAddComputation(ptype); // X^2. - auto operand_squared = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = + add_binary(operand_shape, HloOpcode::kMultiply, operand, operand); // Sum[X]. auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, dimensions_without_feature, @@ -225,71 +239,48 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); - // Fuse two parallel reduces together to improve performance. - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum, squared_sum, operand_squared}, - HloInstruction::FusionKind::kInput); - - sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - squared_sum = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // E[X]. - auto mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); + auto mean = add(Mean(elements_per_feature_int64, sum, add)); auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); + auto square_mean = add(Mean(elements_per_feature_int64, squared_sum, add)); // E^2[X]. - auto mean_square = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, mean, mean)); + auto mean_square = + add_binary(feature_shape, HloOpcode::kMultiply, mean, mean); // Var[X]. - auto var = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); + auto var = + add_binary(feature_shape, HloOpcode::kSubtract, square_mean, mean_square); auto var_broadcasted = add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + auto shifted_normalized = add_binary(operand_shape, HloOpcode::kAdd, + scaled_normalized, offset_broadcasted); auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); @@ -331,8 +322,11 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = computation_->AddInstruction(HloInstruction::CreateBroadcast( + operand_shape, + computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(epsilon_literal))), + {})); std::vector dimensions_without_feature; @@ -349,6 +343,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); auto scale_broadcasted = add( @@ -364,30 +362,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto var_add_epsilon = + add_binary(operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(Rsqrt(var_add_epsilon, add)); // X - E[X]. - auto operand_minus_mean = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add_binary(operand_shape, HloOpcode::kSubtract, + operand, mean_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = add( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, - operand_minus_mean, rsqrt_var_add_epsilon)); + auto normalized = add_binary(operand_shape, HloOpcode::kMultiply, + operand_minus_mean, rsqrt_var_add_epsilon); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = add(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add_binary(operand_shape, HloOpcode::kMultiply, + normalized, scale_broadcasted); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( @@ -435,6 +426,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( added_instructions.push_back(added_inst); return added_inst; }; + auto add_binary = [&](const Shape& shape, const HloOpcode opcode, + HloInstruction* a, HloInstruction* b) { + return add(HloInstruction::CreateBinary(shape, opcode, a, b)); + }; int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); @@ -450,26 +445,20 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature_literal = - Literal::CreateR0(size_in_elements / feature_count); - TF_ASSIGN_OR_RETURN(elements_per_feature_literal, - elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = add( - HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + const int64 elements_per_feature_int64 = size_in_elements / feature_count; auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); - auto neg_half_literal = Literal::CreateR0(-0.5f); - TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = - add(HloInstruction::CreateConstant(std::move(neg_half_literal))); - auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = + auto epsilon_scalar = add(HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon_activation = add( + HloInstruction::CreateBroadcast(activation_shape, epsilon_scalar, {})); + auto epsilon_feature = + add(HloInstruction::CreateBroadcast(feature_shape, epsilon_scalar, {})); std::vector dimensions_without_feature; @@ -489,26 +478,23 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, - epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = + add(Rsqrt(add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation), + add)); + + auto rsqrt_var_add_epsilon = add(Rsqrt( + add_binary(feature_shape, HloOpcode::kAdd, variance, epsilon_feature), + add)); // X - E[X]. - auto activation_minus_mean = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); + auto activation_minus_mean = add_binary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted); // Grad[Y] * (X - E[X]). auto grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + activation_minus_mean); HloComputation* add_reduce_computation = GetOrCreateScalarAddComputation(ptype); @@ -524,25 +510,10 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_ && !batch_norm->has_sharding()) { - auto tuple = add(HloInstruction::CreateTuple( - {sum_grad_output_times_activiation_minus_mean, grad_beta})); - - auto fused = computation_->CreateFusionInstruction( - {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, - HloInstruction::FusionKind::kInput); - - sum_grad_output_times_activiation_minus_mean = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - - grad_beta = - add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); - } - // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = add(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kMultiply, - sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); + auto grad_scale = add_binary(feature_shape, HloOpcode::kMultiply, + sum_grad_output_times_activiation_minus_mean, + rsqrt_var_add_epsilon); // I2 = Sum(Grad[Y]) auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, @@ -554,39 +525,40 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( {feature_index})); // I4 = (X - E[X]) * I3 - auto i4 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); + auto i4 = add_binary(activation_shape, HloOpcode::kMultiply, i3, + activation_minus_mean); // I5 = I4 / (Var[X] + epsilon) - auto i5 = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, i4, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)))); + auto i5 = add_binary(activation_shape, HloOpcode::kDivide, i4, + add_binary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon_activation)); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = + add_binary(activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted); - scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, - elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add( + Mean(elements_per_feature_int64, scale_times_rsqrt_var_add_epsilon, add)); - auto i1 = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto elements_per_feature_literal = + Literal::CreateR0(elements_per_feature_int64); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + auto i1 = add_binary(activation_shape, HloOpcode::kMultiply, grad_output, + add(HloInstruction::CreateBroadcast( + activation_shape, elements_per_feature, {}))); // I6 = I1 - I2 - I5 - auto i6 = add(HloInstruction::CreateBinary( + auto i6 = add_binary( activation_shape, HloOpcode::kSubtract, - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - i1, i2)), - i5)); + add_binary(activation_shape, HloOpcode::kSubtract, i1, i2), i5); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = - add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, i6)); + auto grad_activation = add_binary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6); auto tuple = HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); if (batch_norm->has_sharding()) { @@ -615,8 +587,8 @@ StatusOr BatchNormExpander::Run(HloModule* module) { bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, - rewrite_inference_op_, rewrite_grad_op_, - use_fusion_)) { + rewrite_inference_op_, + rewrite_grad_op_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.h b/tensorflow/compiler/xla/service/batchnorm_expander.h index 4ad987085da91684bb7891070afeefd19be4138f..7ae202c583516443a6263403fb5460d1adbabd97 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -31,11 +31,10 @@ class BatchNormExpander : public HloPassInterface { // When use_fusion is set, a multi-output fusion node is created. BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, - bool rewrite_grad_op = false, bool use_fusion = true) + bool rewrite_grad_op = false) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), - rewrite_grad_op_(rewrite_grad_op), - use_fusion_(use_fusion) {} + rewrite_grad_op_(rewrite_grad_op) {} ~BatchNormExpander() = default; tensorflow::StringPiece name() const override { return "batchnorm_expander"; } @@ -47,7 +46,6 @@ class BatchNormExpander : public HloPassInterface { bool rewrite_training_op_; bool rewrite_inference_op_; bool rewrite_grad_op_; - bool use_fusion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 08d0152e3cfcfcb7ae1e85f72c2f7dc856f5e8b3..1b8b2d204503576c3fcb02f6d5b37f2db45e1768 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -182,15 +182,26 @@ Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { Status BFloat16ConversionFoldingVisitor::HandleCrossReplicaSum( HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape()) || - !bfloat16_support_->SupportsMixedPrecisions(*crs)) { - return DefaultAction(crs); - } - // First use DefaultAction() to handle the operands. It can't handle // tuple-shaped output. TF_RETURN_IF_ERROR(DefaultAction(crs)); + if (!bfloat16_support_->SupportsMixedPrecisions(*crs)) { + return Status::OK(); + } + + // If the output is not a tuple, we don't need special handling. + if (!ShapeUtil::IsTuple(crs->shape())) { + return Status::OK(); + } + + // If crs is the root instruction, we should keep its original output type. + // The root instruction implicitly has a use from being the result of the + // computation, and the code below does not take this use into account. + if (crs == computation_->root_instruction()) { + return Status::OK(); + } + // Then do per-tuple-element handling on the output. std::vector> per_tuple_element_gtes( crs->operand_count()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index 28e71c2054f59ba4d5d096bf7d898161877bb42f..f7b4c1405dbc8719d8fba5476e6e41d2921ea877 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { auto builder = HloComputation::Builder(TestName()); + + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("add"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b})); + ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, + sum, /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); HloInstruction* gte_b = builder.AddInstruction( @@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* tuple = builder.AddInstruction( HloInstruction::CreateTuple({gte_a, convert_gte_b})); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(FoldConversions(module.get())); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 1afaefd9df9c5771fb9e134ae9050f3abb00ea4a..830f26422bdc2b3bd789e7d5926bcebac815d34a 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { } TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto module = CreateNewModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + auto builder = HloComputation::Builder(TestName()); Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); @@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( - ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, + /*replica_group_ids=*/{}, /*barrier=*/"")); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); - auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(Normalize(module.get())); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index c0b8bf903923a327fb1378eafb51a7d493d5e62d..5d3b0cb333928d3b7b042ef0a6f4969f87655d7f 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -135,6 +135,7 @@ Status GatherComputationsByAllocationType( worklist.push_back(std::make_pair(subcomputation, false)); // Not thread local. break; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: @@ -630,9 +631,8 @@ Status BufferAssignment::ComputeSummaryStats() { } } if (module_sequence.size() == module_->computation_count()) { - TF_ASSIGN_OR_RETURN( - const int64 min_size, - MinimumMemoryForSequence(module_sequence, buffer_size_)); + TF_ASSIGN_OR_RETURN(const int64 min_size, + MinimumMemoryForModule(module_sequence, buffer_size_)); stats_.total_fragmentation_bytes = stats_.total_allocation_bytes - min_size; } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index a4fb0eefaca094898ed9acad8062484d1a36afe7..efa4696130ffeff669b0d674438a45c5a9d48ef2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/call_graph.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -33,12 +32,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" @@ -82,7 +81,7 @@ const std::vector GetInstructions(HloInstruction* root) { class BufferAssignmentTest : public HloTestBase { protected: - BufferAssignmentTest() : computation_tracker_() {} + BufferAssignmentTest() {} ~BufferAssignmentTest() override {} std::unique_ptr RunBufferAssignment(HloModule* module, @@ -252,9 +251,6 @@ class BufferAssignmentTest : public HloTestBase { return total_size; } - // Computation tracker for nested computations. - ComputationTracker computation_tracker_; - // Shapes for use in the examples. Shape s32_ = ShapeUtil::MakeShape(xla::S32, {}); Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {}); @@ -375,11 +371,11 @@ TEST_F(BufferAssignmentTest, Basic) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -422,11 +418,11 @@ TEST_F(BufferAssignmentTest, BasicUniquelyColored) { // share anything. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -481,11 +477,11 @@ TEST_F(BufferAssignmentTest, BasicPartiallyColored) { // have the color 0, which allows the mul and add to share buffers. auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -551,11 +547,11 @@ TEST_F(BufferAssignmentTest, MultipleUsersForNode) { // auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -605,7 +601,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation)); module->AddEntryComputation(builder.Build()); @@ -658,7 +654,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10_, "")); + HloInstruction::CreateParameter(0, f32a100x10_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0)); auto exp2 = builder.AddInstruction( @@ -822,7 +818,7 @@ TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto exp1 = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0)); auto tanh = builder.AddInstruction( @@ -1500,11 +1496,11 @@ TEST_F(BufferAssignmentTest, TrivialPeakBuffers) { // param1[100] --------------/--------/ auto builder = HloComputation::Builder(TestName()); auto paramscalar = - builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "")); + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "p")); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, f32vec100_, "")); + HloInstruction::CreateParameter(1, f32vec100_, "p1")); auto param1 = builder.AddInstruction( - HloInstruction::CreateParameter(2, f32vec100_, "")); + HloInstruction::CreateParameter(2, f32vec100_, "p2")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec100_, HloOpcode::kMultiply, paramscalar, param0)); auto add = builder.AddInstruction( @@ -1540,7 +1536,7 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { // be {%rev, %neg, %concat}. This occurs right at the concat itself. auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32vec100_, "")); + HloInstruction::CreateParameter(0, f32vec100_, "p")); auto log = builder.AddInstruction( HloInstruction::CreateUnary(f32vec100_, HloOpcode::kLog, param)); auto rev = builder.AddInstruction( @@ -1677,7 +1673,7 @@ class WhileBufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( module, xla::MakeUnique(module, sequence), ByteSizeOf, @@ -1797,7 +1793,7 @@ ENTRY %test_module { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); // Run CopyInsertion and check if the graph constructed above doesn't need // any copies inserted for BufferAssignment to run. @@ -2107,7 +2103,7 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { RunCopyInsertion(module.get()); auto sequence = - CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + ScheduleComputationsInModule(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index a8053d15e124319c5c898f0034b9aaa95a007a89..a23427f00ccd88bb0fe1d973a667f80ca54b14cd 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) { case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; + case HloOpcode::kCrossReplicaSum: case HloOpcode::kMap: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h index c7763f2ca3e68490cd0cd9b4ba4d7bd180134080..fac0afd672ff3ed083aacf778dd9c4f90a2ee870 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.h +++ b/tensorflow/compiler/xla/service/channel_tracker.h @@ -19,9 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc deleted file mode 100644 index b16907da9e9c909d2639f83895db27d724a84a7b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/compilation_cache.h" - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { - -std::shared_ptr CompilationCache::Insert( - std::unique_ptr executable, - const HloModuleConfig& module_config) { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = - BuildKey(executable->entry_computation_handle(), module_config); - VLOG(2) << "inserting cache key: " << key; - if (cache_.count(key) == 0) { - cache_.emplace(key, std::move(executable)); - } else { - // Executable already exists in the cache. This can happen if two Execute - // calls for a new computation are received simultaneously by the - // service. In this case, we discard the Executable given as a parameter and - // return what is in the cache. This is necessary because the service relies - // on the cache to keep ownership of the Executable. We only want to store - // one Executable for a given computation version and we can't discard the - // executable which is in the cache because it may be in use. - executable.reset(); - } - return cache_.at(key); -} - -std::shared_ptr CompilationCache::LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - tensorflow::mutex_lock lock(mutex_); - - CacheKey key = BuildKey(versioned_handle, module_config); - VLOG(2) << "looking up cache key: " << key; - if (cache_.count(key) == 0) { - VLOG(2) << "cache key not found: " << key; - return nullptr; - } else { - std::shared_ptr result = cache_.at(key); - VLOG(2) << "hit executable with module config: " - << result->module_config().compilation_cache_key(); - return result; - } -} - -CompilationCache::CacheKey CompilationCache::BuildKey( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const { - // The computation shape is represented entirely by its ProgramShape member, - // so just serialize the proto as part of the key. - return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::", - versioned_handle.version, "::", - module_config.compilation_cache_key()); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h deleted file mode 100644 index 09989726ae6629aa65cb1dd84c16408a75019fa5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/compilation_cache.h +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" - -namespace xla { - -// A cache which stores Executables indexed by computation handle and version. -class CompilationCache { - public: - CompilationCache() {} - - // Insert the given Executable into the cache. Return a bare Executable - // pointer for the caller to use. Note: the returned pointer will *not* be the - // same as the given unique pointer if the computation already exists in the - // cache. See comments in the .cc implementation for details of this case. - // - // module_config is provided by the caller, instead of being taken from the - // executable, so that we can insert keys into the compilation cache that are - // devoid of layout (where XLA gets to choose what layout to compile). - // - // A shared_ptr is returned so the caller can keep the Executable from being - // destructed in the event that the Executable is evicted from the - // computation cache (and the cache's shared_ptr to the Executable is - // destructed). - std::shared_ptr Insert(std::unique_ptr executable, - const HloModuleConfig& module_config); - - // Lookup the Executable for the specified versioned computation in the cache. - // Return a shared_ptr to the Executable if it exists in the cache. Return - // nullptr otherwise. - std::shared_ptr LookUp( - const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - - protected: - mutable tensorflow::mutex mutex_; - - // Map from versioned handle with program layout to Executable built - // for that computation version and program layout. - using CacheKey = string; - - CacheKey BuildKey(const VersionedComputationHandle& versioned_handle, - const HloModuleConfig& module_config) const; - std::map> cache_ GUARDED_BY(mutex_); - - private: - TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_ diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index d39fd7307ae1b5bd0c431f98c413011ca081050b..7426672a7a2a9102bd5ea98bd51092982e1e09b4 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -64,7 +63,8 @@ CompileOnlyService::CompileOnlyService(const ServiceOptions& options, StatusOr>> CompileOnlyService::CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { + const AotCompilationOptions& options, + std::unique_ptr* metadata) { std::vector> hlo_modules; for (const AotXlaComputationInstance& instance : computations) { TF_RET_CHECK(instance.computation.has_program_shape()); @@ -101,59 +101,8 @@ CompileOnlyService::CompileAheadOfTime( hlo_modules.push_back(std::move(hlo_module)); } - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); -} - -StatusOr>> -CompileOnlyService::CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& options) { - std::vector> hlo_modules; - for (const AotComputationInstance& instance : computations) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(instance.computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - const DebugOptions& debug_options = options.debug_options(); - - // Dump computation proto state if flag is set. - const string& directory_path = debug_options.xla_dump_computations_to(); - if (!directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - string filename = tensorflow::strings::StrCat( - "computation_", versioned_handle.handle.handle(), "__", - session_module->entry().name(), "__version_", - versioned_handle.version); - const string& per_host_path = tensorflow::io::JoinPath( - directory_path, tensorflow::port::Hostname()); - - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(per_host_path, filename, - *session_module)); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - ExecutionOptions execution_options; - *execution_options.mutable_debug_options() = debug_options; - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, user_computation)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, - computation_tracker_.BuildHloModule( - versioned_handle, *module_config, - /*include_unreachable_instructions=*/true)); - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); - hlo_modules.push_back(std::move(hlo_module)); - } - - return compiler_->CompileAheadOfTime(std::move(hlo_modules), options); + return compiler_->CompileAheadOfTime(std::move(hlo_modules), options, + metadata); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 7f2ce0e8974c01b09664235d7b9d19555b2705a3..1ac950bdd66bd034dfdafa8598ec506221e99c2f 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -38,24 +38,7 @@ class CompileOnlyService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // A description of a computation to compile using CompileAheadOfTime. - struct AotComputationInstance { - ComputationHandle computation; - std::vector argument_layouts; - const Shape* result_layout = nullptr; - }; - - // Compiles a list of computations for ahead-of-time execution. This is - // intended for use in static compilation. See - // |CompileOnlyClient::CompileAheadOfTime| for additional details. - StatusOr>> - CompileAheadOfTime( - const tensorflow::gtl::ArraySlice computations, - const AotCompilationOptions& Options); - // A description of a xla computation to compile using CompileAheadOfTime. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. struct AotXlaComputationInstance { HloModuleProto computation; std::vector argument_layouts; @@ -65,31 +48,21 @@ class CompileOnlyService : public Service { // Compiles a list of xla computations for ahead-of-time execution. This is // intended for use in static compilation. See // |CompileOnlyClient::CompileAheadOfTime| for additional details. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr>> CompileAheadOfTime( const tensorflow::gtl::ArraySlice computations, const AotCompilationOptions& options); - // Override Service methods that require or imply the existence of an - // execute backend. Note that this does not include TransferToClient, as - // computing constants produces global data that we may wish to transfer. - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } + StatusOr>> + CompileAheadOfTime( + const tensorflow::gtl::ArraySlice computations, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override { return Unimplemented("CompileOnlyService does not support devices."); } - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override { - return Unimplemented("CompileOnlyService does not support execution."); - } Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) override { return Unimplemented("CompileOnlyService does not support execution."); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 8b01a6c4b5004d03e6e7d23b99b923fdcdeaff99..0dceed853dcbae211657f00433866cfe10c51fc7 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -28,6 +28,27 @@ namespace xla { /* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( tensorflow::LINKER_INITIALIZED); +std::vector> +Compiler::ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const { + CHECK(executor != nullptr); + return {}; +} + +// Define a default version where metadata is not used. +StatusOr>> +Compiler::CompileAheadOfTime( + std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata) { + if (metadata != nullptr) { + return Unimplemented( + "Populating AotCompilationMetadata is not implemented on this " + "compiler."); + } + return CompileAheadOfTime(std::move(modules), options); +} + /* static */ std::map* Compiler::GetPlatformCompilerFactories() { static auto* r = new std::map; diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index a4b59d1ba9b24e3f886a7feb51181ae8f990951f..d1144f97bb2ab29d3d18f3b3f65a38af46e68dd1 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -24,9 +24,11 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/service/buffer_value.h" #include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" @@ -34,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -91,6 +94,19 @@ class AotCompilationOptions { DebugOptions debug_options_; }; +// Abstract superclass describing metadata produced during ahead-of-time +// compilation. +class AotCompilationMetadata { + public: + AotCompilationMetadata(const AotCompilationMetadata&) = delete; + AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; + + virtual ~AotCompilationMetadata() = default; + + protected: + AotCompilationMetadata() = default; +}; + // Abstract compiler interface that is subclassed for compilation on a // particular platform. // @@ -153,12 +169,29 @@ class Compiler { std::vector> stream_exec, DeviceMemoryAllocator* device_allocator) = 0; + // Returns the backend configurations that the backend will consider for the + // given HLO. Returns no configurations if the backend does not support + // configurations for the given HLO. + // + // The stream executor is passed in to provide information about the hardware + // that the backend configurations would be targeting. + virtual std::vector> + ComputeBackendConfigs(const HloInstruction& hlo, + se::StreamExecutor* executor) const; + // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. virtual StatusOr>> CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& options) = 0; + // Similar to CompileAheadOfTime above but AotCompilationMetadata + // has an argument that can be populated during compilation. + virtual StatusOr>> + CompileAheadOfTime(std::vector> modules, + const AotCompilationOptions& options, + std::unique_ptr* metadata); + ///// // The Compiler class also serves as a point to register compiler objects // for the various platforms. diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h index 53c3a3f7b738687db3098acfaef1ae87860d0440..6975f387b4864bf28ea0ad23d7d4602b5b346e08 100644 --- a/tensorflow/compiler/xla/service/computation_layout.h +++ b/tensorflow/compiler/xla/service/computation_layout.h @@ -32,12 +32,21 @@ namespace xla { // mutable layouts. class ComputationLayout { public: + // Creates a new ComputationLayout with the given result layout. + explicit ComputationLayout(ShapeLayout result_layout) + : result_layout_(std::move(result_layout)) {} + // Constructs a ComputationLayout from a ProgramShape. The layouts of the // parameters and results are set to the default layout. Layouts in the // ProgramShape are ignored if ignore_layouts is true. explicit ComputationLayout(const ProgramShape& program_shape, bool ignore_layouts = true); + // Adds a new parameter layout to the computation layout. + void add_parameter_layout(ShapeLayout shape_layout) { + parameter_layouts_.push_back(std::move(shape_layout)); + } + // Returns the layout of a particular parameter. const ShapeLayout& parameter_layout(int64 param_no) const { return parameter_layouts_[param_no]; diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc deleted file mode 100644 index 70e25eebdb068db893e24aec0f72d09090ac7027..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.cc +++ /dev/null @@ -1,256 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/computation_tracker.h" - -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" - -using ::tensorflow::strings::Appendf; - -namespace xla { - -ComputationTracker::ComputationTracker() : next_computation_(1) {} - -ComputationHandle ComputationTracker::NewComputation( - const string& computation_name) { - tensorflow::mutex_lock lock(computation_mutex_); - ComputationHandle computation_handle; - int64 handle_value = next_computation_++; - computation_handle.set_handle(handle_value); - opaque_to_computation_[handle_value] = - MakeUnique(computation_name, computation_handle); - return computation_handle; -} - -StatusOr ComputationTracker::LoadSessionModule( - const SessionModule& session_module) { - tensorflow::mutex_lock lock(computation_mutex_); - - // For each embedded computation, create a new computation based on its - // serialized data, and place the mapping from the old computation handle to - // the new computation handle. - - // Build a mapping from old embedded computation handles to new computation - // handles. We build the ID mapping first since the embedded computations are - // in no particular order and may refer to each other. - std::map old_to_new; - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - if (!old_to_new.emplace(old_handle, AllocateHandle()).second) { - return InvalidArgument("Duplicate embedded computation handle %lld", - old_handle); - } - } - - // Create a new computation from each serialized embedded computation. - for (const SessionComputation& computation : - session_module.embedded_computations()) { - const int64 old_handle = computation.computation_handle().handle(); - const ComputationHandle& new_handle = old_to_new[old_handle]; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - computation, new_handle, old_to_new)); - } - - // Finally, place the entry computation in the tracker with all of the - // remappings populated from the above. - const int64 old_handle = session_module.entry().computation_handle().handle(); - TF_ASSIGN_OR_RETURN( - old_to_new[old_handle], - LoadSessionComputation(session_module.entry(), &old_to_new)); - return old_to_new[old_handle]; -} - -StatusOr> -ComputationTracker::SnapshotComputation(const ComputationHandle& computation) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation)); - const VersionedComputationHandle entry_versioned_handle = - user_computation->GetVersionedHandle(); - std::set visited; - std::list post_order; - { - tensorflow::mutex_lock lock(computation_mutex_); - ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order); - } - auto session_module = MakeUnique(); - *session_module->mutable_entry() = - Resolve(entry_versioned_handle.handle) - .ValueOrDie() - ->CloneSessionComputation(entry_versioned_handle.version); - for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) { - *session_module->add_embedded_computations() = - Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version); - } - return std::move(session_module); -} - -StatusOr ComputationTracker::Resolve( - const ComputationHandle& computation) const { - tensorflow::mutex_lock lock(computation_mutex_); - return ResolveInternal(computation); -} - -ComputationHandle ComputationTracker::AllocateHandle() { - int64 handle_value = next_computation_++; - ComputationHandle result; - result.set_handle(handle_value); - return result; -} - -StatusOr ComputationTracker::LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) { - TF_RET_CHECK(old_to_new != nullptr); - const ComputationHandle new_handle = AllocateHandle(); - (*old_to_new)[session_computation.computation_handle().handle()] = new_handle; - TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()], - UserComputation::MakeWithRemapping( - session_computation, new_handle, *old_to_new)); - return new_handle; -} - -StatusOr ComputationTracker::ResolveInternal( - const ComputationHandle& computation) const { - auto it = opaque_to_computation_.find(computation.handle()); - if (it == opaque_to_computation_.end()) { - return NotFound("computation handle not found: %lld", computation.handle()); - } - UserComputation* user_computation = it->second.get(); - return user_computation; -} - -void ComputationTracker::ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const { - if (visited->count(versioned_handle) > 0) { - CHECK_EQ(1, visited->count(versioned_handle)); - return; - } - - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - std::vector embedded_handles = - computation->GetEmbeddedComputations(versioned_handle.version); - - for (const auto& embedded_handle : embedded_handles) { - ComputeComputationPostOrder(embedded_handle, visited, post_order); - } - - visited->insert(versioned_handle); - post_order->push_back(versioned_handle); -} - -StatusOr> ComputationTracker::BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(computation_mutex_); - - VLOG(1) << "BuildHloModule(" << entry_handle - << ", include_unreachable_instructions=" - << include_unreachable_instructions << ")"; - XLA_VLOG_LINES(1, ToStringInternal()); - - TF_ASSIGN_OR_RETURN(UserComputation * entry_computation, - ResolveInternal(entry_handle.handle)); - - // Build a topological sort of the entry and any embedded computations as a - // list. The root of the computation will be the last element in the list. - std::set visited; - std::list post_order; - ComputeComputationPostOrder(entry_handle, &visited, &post_order); - - // Map from ComputationHandle value and computation version to HloComputation. - std::map hlo_computations; - - // The resolver lambda resolves VersionedHandles to embedded - // HloComputation*. This is required by UserComputation::BuildHloComputation - // when lowering calling operations (map, reduce etc). - auto resolver = [&hlo_computations]( - const VersionedComputationHandle& versioned_handle) -> HloComputation* { - CHECK_GT(hlo_computations.count(versioned_handle), 0); - return hlo_computations.at(versioned_handle); - }; - - // Print the post-order list for this entry computation. - if (VLOG_IS_ON(2)) { - VLOG(2) << "Visiting UserComputations in post order:"; - for (const VersionedComputationHandle& versioned_handle : post_order) { - VLOG(2) << " " << versioned_handle; - } - } - - string module_name = - tensorflow::strings::StrCat(entry_computation->name(), "_module"); - auto module = MakeUnique(module_name, entry_handle, config); - for (auto versioned_handle : post_order) { - UserComputation* computation = - ResolveInternal(versioned_handle.handle).ValueOrDie(); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - computation->BuildHloComputation(versioned_handle.version, resolver, - config.debug_options(), - include_unreachable_instructions)); - - // Add the newly created computation to VersionedHandle-to-HloComputation - // map. - DCHECK_EQ(0, hlo_computations.count(versioned_handle)); - hlo_computations[versioned_handle] = hlo_computation.get(); - - if (computation == entry_computation) { - module->AddEntryComputation(std::move(hlo_computation)); - } else { - module->AddEmbeddedComputation(std::move(hlo_computation)); - } - } - - return std::move(module); -} - -string ComputationTracker::ToString() const { - tensorflow::mutex_lock lock(computation_mutex_); - return ToStringInternal(); -} - -string ComputationTracker::ToStringInternal() const { - string out; - Appendf(&out, "ComputationTracker(%p):\n", this); - for (const auto& handle_computation : opaque_to_computation_) { - int64 handle = handle_computation.first; - const std::unique_ptr& computation = - handle_computation.second; - Appendf(&out, " %4lld : %s \"%s\"\n", handle, - computation->GetVersionedHandle().ToString().c_str(), - computation->name().c_str()); - } - return out; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h deleted file mode 100644 index d42d66adefe7faa2751da4cd80b392a38917ce70..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/computation_tracker.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// Tracks computations for the XLA service; computations can be registered -// with a UserComputation instance and can be resolved from a handle for later -// use. -// -// This class is also capable of serializing/deserializing computations that it -// tracks (and to serialize properly you need to serialize all referred-to -// computations as well). -class ComputationTracker { - public: - ComputationTracker(); - - // Creates a new UserComputation object and returns the corresponding - // ComputationHandle for it. - // - // Precondition: user_computation is not already present in the map. - ComputationHandle NewComputation(const string& computation_name); - - // Restores session data for a computation that has been serialized, and - // allocates a new computation handle for it. - StatusOr LoadSessionModule( - const SessionModule& session_module); - - // Snapshots a computation (referenced by the provided handle) at its latest - // version, returning a module where it is the entry, and any referred-to - // computations are entrained as "embedded" (non-entry) computations. - StatusOr> SnapshotComputation( - const ComputationHandle& computation); - - // Resolves a ComputationHandle to a UserComputation that is present in the - // map. - StatusOr Resolve( - const ComputationHandle& computation) const; - - // Builds an HLO module using the specified computation as the entry. The - // module will include the entry computation as well as all computations which - // are called directly or indirectly from the entry computation via operations - // like "map". config is the HLO module configuration to use for the - // constructed module. - // If include_unreachable_instructions is true, then instructions - // which are not reachable from the root are lowered into HloInstructions - // including unreachable parameters. This ensures the entry HloComputation has - // the same program shape (ProgramShape) as the entry UserComputation. - StatusOr> BuildHloModule( - const VersionedComputationHandle& entry_handle, - const HloModuleConfig& config, - bool include_unreachable_instructions = true) const; - - string ToString() const; - - private: - // Bumps the next_computation_ number and returns the allocated number wrapped - // in a ComputationHandle. - ComputationHandle AllocateHandle() - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Loads a session computation into a UserComputation, registers it, and - // returns the computation handle of the registered computation. If old_to_new - // is provided, it is used for remapping references to computations present in - // session_computation. - // - // old_to_new will be updated with the mapping from session_computation's old - // handle to the returned handle value, and may not be null. - StatusOr LoadSessionComputation( - const SessionComputation& session_computation, - std::map* old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Internal implementation of Resolve method which requires, but does not - // acquire the mutex. - StatusOr ResolveInternal( - const ComputationHandle& computation) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Builds a post order sort of a computation ("entry") and all of its embedded - // computations including all transitively embedded computations. An embedded - // computation (the callee) will always appear in the sort before the - // computation which calls the embedded computation (the caller). Necessarily, - // the entry computation is the last element in the sort. visited and - // post_order should be empty when calling. post_order contains the post order - // sort when the function return. - void ComputeComputationPostOrder( - const VersionedComputationHandle& versioned_handle, - std::set* visited, - std::list* post_order) const - EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_); - - // Guards the computation mapping. Marked mutable so that the Resolve method - // can remain const; Resolve does't really modify the tracker in any way, but - // it has to lock the mutex for safety. - mutable tensorflow::mutex computation_mutex_; - - // The next sequence number to assign to a computation, guarded by the same - // mutex as the mapping as they'll be mutated at the same time. - int64 next_computation_ GUARDED_BY(computation_mutex_); - - // Mapping from ComputationHandle value to the corresponding registered - // UserComputation object. - std::map> opaque_to_computation_ - GUARDED_BY(computation_mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 153f062d015e49db11c4c9ae0a2a61e76c020f02..684fff8a6fd4fd045a0de9df9660b887f32bdf40 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1636,8 +1636,7 @@ void BM_SequentialWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_SequentialWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1677,8 +1676,7 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { for (int i = 0; i < num_iters; ++i) { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); + HloModule module("BM_SequentialWhiles", config); auto builder = HloComputation::Builder("BM_ParallelWhiles"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1750,8 +1748,7 @@ void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { std::vector tuple_params(num_tuple_inputs); for (int i = 0; i < num_iters; ++i) { auto builder = HloComputation::Builder("BM_ParallelWhiles"); - HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), - config); + HloModule module("BM_ManyElementTuple", config); for (int j = 0; j < num_tuple_inputs; ++j) { tuple_params[j] = builder.AddInstruction( HloInstruction::CreateParameter(j, element_shape, "")); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index bfd85f257fb9550a6babb2459a7227ca9003a14f..b703be0f39e2032bc58479f0b957f9d8b01a77c3 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -151,7 +151,14 @@ cc_library( "@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep "@llvm//:x86_disassembler", # fixdeps: keep - ], + ] + select({ + "@org_tensorflow//tensorflow:linux_ppc64le": [ + "@llvm//:powerpc_disassembler", + "@llvm//:powerpc_code_gen", + ], + "//conditions:default": [ + ], + }), alwayslink = True, # Contains compiler registration ) @@ -649,10 +656,10 @@ tf_cc_test( deps = [ ":cpu_instruction_fusion", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", ], ) @@ -706,9 +713,9 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -898,6 +905,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", ], @@ -958,7 +966,7 @@ tf_cc_test( ":ir_emission_utils", ":target_machine_features_fake", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 25b18eff20f901fc34343a12bfbd353ecec49cfb..d039132535071661d047579587385210719fede3 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -264,8 +264,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile, pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, @@ -550,8 +549,8 @@ StatusOr> CpuCompiler::RunBackend( // and reduced memory usage (as compared to using DependencyHloOrdering). TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction(), - DFSMemoryScheduler)); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction(), + DFSMemoryScheduler)); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -730,7 +729,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, - CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction())); + ScheduleComputationsInModule(*module, BufferSizeBytesFunction())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc index d12fa6bb9ad2054bdc052c9d7b3729cc28e11f6d..8727c72b6e42517b1859e98ecadb41bbceed761c 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_eigen_tensor_alignment_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace cpu { @@ -40,7 +40,7 @@ ENTRY DotOperation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* dot = module->entry_computation()->root_instruction(); @@ -71,7 +71,7 @@ ENTRY ConvOperation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* conv = module->entry_computation()->root_instruction(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 46fe060817b0264d90574b45a94cf1f6e5964593..97e10a89a209c057685709e7a5034052ff4376ed 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -172,7 +172,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -202,7 +202,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -233,7 +233,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* computation = module->entry_computation(); TransposeFolding transpose_folding( @@ -775,7 +775,7 @@ TEST_P(GatherLoopFusionTest, GatherLoopFusion) { string hlo_string = tensorflow::strings::StrCat( "HloModule ", spec.test_name, "\n\n", spec.hlo_computation_text); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); RunFusionAndCheckOpcodesWereFused( module.get(), diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index e75fcb6bc9719f7453d5f0cb52d1673cef1fd3df..3ed7876715f64191f6e652d2b5cb1673df9a1b94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace { @@ -24,6 +25,7 @@ const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaEnableExperimentalLlvmIrGemm = "xla_enable_experimental_llvm_ir_gemm"; +const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -62,6 +64,43 @@ bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaEnableExperimentalLlvmIrGemm) > 0; } +static tensorflow::StringPiece RemoveSuffix(tensorflow::StringPiece str, + tensorflow::StringPiece suffix) { + CHECK_GE(str.size(), suffix.size()); + CHECK_EQ(str.substr(str.size() - suffix.size()), suffix); + return str.substr(0, str.size() - suffix.size()); +} + +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrGemmTileSize); + if (it == extra_options_map.end()) { + return tensorflow::gtl::nullopt; + } + + std::vector tile_components = + tensorflow::str_util::Split(it->second, ':'); + CHECK_EQ(tile_components.size(), 3); + + int64 tile_size_m; + int64 tile_size_k; + int64 tile_size_n_in_vector_width; + + CHECK(tensorflow::strings::safe_strto64(tile_components[0], &tile_size_m)); + CHECK(tensorflow::strings::safe_strto64(tile_components[1], &tile_size_k)); + + tensorflow::StringPiece tile_size_n_in_vector_width_str = + RemoveSuffix(tile_components[2], "*vectwidth"); + + CHECK(tensorflow::strings::safe_strto64(tile_size_n_in_vector_width_str, + &tile_size_n_in_vector_width)); + + return std::tuple(tile_size_m, tile_size_k, + tile_size_n_in_vector_width); +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 106dfbbc62dfba8d3de74e0a2ae3bb247bd91d67..429b9e16cbdd6f623919533582481f1640118081 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -29,6 +29,8 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config); bool EnableExperimentalLlvmIrGemm(const HloModuleConfig& config); tensorflow::gtl::optional LlvmIrGemvTilingFactor( const HloModuleConfig& config); +tensorflow::gtl::optional> LlvmIrGemmTileSize( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index af69fc3da9869aa2df958ecc5c064ee37dd9ea21..8eb39d615fd482cdcea716ba7b105c643a2d8b87 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -42,17 +42,17 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { namespace { -// Loads a tile of values from a 2D tensor. -class TileLoader { +// Provides tiled access to an in-memory rank 2 array. +class MemoryTile { public: - // Constructs a TileLoader that will load a tile consisting of + // Constructs a MemoryTile that can operate on tiles consisting of // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at // `major_dim_offset` in the major dimension. The tile size along the minor // dimension is the vector size, and that is implicitly determined by `vsl`. - TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + MemoryTile(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, llvm::Value* matrix, int64 matrix_size_along_minor_dim, llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) - : vsl_(vsl) { + : vsl_(vsl), ir_builder_(ir_builder) { pointers_.reserve(tile_size_along_major_dim); for (int64 i = 0; i < tile_size_along_major_dim; i++) { llvm::Value* total_offset = ir_builder->CreateMul( @@ -62,9 +62,10 @@ class TileLoader { } } - // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at - // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the - // minor dimension. + // Load a tile consisting of `tile_size_along_major_dim` vectors from position + // {major: `major_dim_offset`, minor: `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. std::vector LoadTile(llvm::Value* minor_dim_offset) const { std::vector result; result.reserve(pointers_.size()); @@ -74,11 +75,104 @@ class TileLoader { return result; } + // Stores `tile` to position {major: `major_dim_offset`, minor: + // `minor_dim_offset`}. + // + // Note: `major_dim_offset` is a parameter to the constructor. + void StoreTile(tensorflow::gtl::ArraySlice tile, + llvm::Value* minor_dim_offset) const { + CHECK_EQ(tile.size(), pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + vsl_->StoreVector(tile[i], pointers_[i], minor_dim_offset); + } + } + + // Loads a tile of size [`tile_size_along_major_dim`, + // `tile_size_along_middle_dim`] from position {major: `major_dim_offset`, + // minor: `minor_dim_offset`} and then broadcasts each element into a vector + // of size vsl_.vector_size(). The (i,j)'th element of the return value is + // the (i,j)'th element in the tile broadcasted into an LLVM vector. + // + // Note: `major_dim_offset` is a parameter to the constructor. + std::vector> LoadBroadcastTile( + llvm::Value* minor_dim_offset, int64 tile_size_along_middle_dim) const { + std::vector> result; + result.resize(pointers_.size()); + for (int64 i = 0; i < pointers_.size(); i++) { + for (int64 j = 0; j < tile_size_along_middle_dim; j++) { + result[i].push_back(vsl_->LoadBroadcast( + pointers_[i], ir_builder_->CreateAdd(minor_dim_offset, + ir_builder_->getInt64(j)))); + } + } + return result; + } + private: VectorSupportLibrary* vsl_; + llvm::IRBuilder<>* ir_builder_; std::vector pointers_; }; +// The base class for the classes representing the GEMV emitter configurations. +// +// The IR emitted (modulo the LLVM values representing the input and output +// buffers) by the row major and column major GEMV emitters should be a function +// of their configuration. This is important because their configuration is +// used as a key to cache the generated IR. +class GemvConfig { + public: + // Mixin for convenience. + template + struct User { + public: + PrimitiveType scalar_type() const { + return derived().config().scalar_type(); + } + int64 tile_rows() const { return derived().config().tile_rows(); } + int64 tile_cols() const { return derived().config().tile_cols(); } + int64 m() const { return derived().config().m(); } + int64 k() const { return derived().config().k(); } + int64 has_addend() const { return derived().config().has_addend(); } + + private: + const T& derived() const { return *static_cast(this); } + }; + + PrimitiveType scalar_type() const { return scalar_type_; } + int64 tile_rows() const { return tile_rows_; } + int64 tile_cols() const { return tile_cols_; } + int64 m() const { return m_; } + int64 k() const { return k_; } + bool has_addend() const { return has_addend_; } + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + name_, "_", PrimitiveType_Name(scalar_type()), "_", tile_rows(), "_", + tile_cols(), "_", m(), "_", k(), has_addend() ? "_with_addend" : ""); + } + + protected: + explicit GemvConfig(string name, PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, bool has_addend) + : name_(std::move(name)), + scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + has_addend_(has_addend) {} + + private: + string name_; + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + bool has_addend_; +}; + // Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the // layout of the vector does not matter). This implementation uses a tiling // scheme to improve performance. @@ -140,38 +234,46 @@ class TileLoader { // TODO(sanjoy): We should investigate if using gather loads and scatter stores // can be used here have the same inner loop for both column-major and row-major // matrix-vector products. -class ColumnMajorMatrixVectorProductEmitter { +class ColumnMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, - int64 tile_rows, int64 tile_cols, - int64 m, int64 k, llvm::Value* lhs, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"col_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + ColumnMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, llvm::Value* rhs, llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { - CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast(tile_rows_))); + vsl_(config.scalar_type(), /*vector_size=*/config.tile_rows(), + ir_builder_, "") { + CHECK(tile_rows() > 0 && IsPowerOfTwo(static_cast(tile_rows()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: void EmitOuterLoopBody(llvm::Value* column, int64 column_count, bool is_first_column); - TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/m_, + MemoryTile GetLhsMemoryTile(llvm::Value* column_start, int64 column_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m(), /*major_dim_offset=*/column_start, /*tile_size_along_major_dim=*/column_count); } @@ -188,18 +290,14 @@ class ColumnMajorMatrixVectorProductEmitter { return result; } - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column); void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -211,26 +309,26 @@ class ColumnMajorMatrixVectorProductEmitter { void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( llvm::Value* column, int64 column_count, bool is_first_column) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*column_start=*/column, /*column_count=*/column_count); std::vector rhs_tile = LoadRhsTile(column, /*count=*/column_count); - EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + EmitInnerLoopTiled(&lhs_memory_tile, rhs_tile, /*columns=*/column_count, is_first_column); EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); } void ColumnMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. - int64 column_remainder = k_ % tile_cols_; - int64 column_limit = k_ - column_remainder; + int64 column_remainder = k() % tile_cols(); + int64 column_limit = k() - column_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, - [&](llvm::Value* column, bool is_first_column) { - EmitOuterLoopBody(column, tile_cols_, is_first_column); - }); + ksl_.ForReturnVoid("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols(), + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols(), is_first_column); + }); if (column_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, @@ -239,29 +337,30 @@ void ColumnMajorMatrixVectorProductEmitter::Emit() { } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + MemoryTile* lhs_memory_tile, const std::vector& rhs_tile, int64 columns, bool is_first_column) { - int64 row_limit = m_ - (m_ % tile_rows_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, - /*step=*/tile_rows_, [&](llvm::Value* row) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = - is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) - : vsl_.GetZeroVector()) - : vsl_.LoadVector(result_, row); - for (int i = 0; i < columns; i++) { - accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); - } - vsl_.StoreVector(accumulator, result_, row); - }); + int64 row_limit = m() - (m() % tile_rows()); + + ksl_.ForReturnVoid( + "dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows(), [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); } void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { - int64 row_start = m_ - (m_ % tile_rows_); - if (row_start == m_) { + int64 row_start = m() - (m() % tile_rows()); + if (row_start == m()) { return; } @@ -274,25 +373,25 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // initialized. // } - ksl_.For( + ksl_.ForReturnVoid( "dot.inner.epilg.outer", /*start=*/current_tile_col, /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), /*step=*/1, /*peel_first_iteration=*/false, [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); llvm::Value* total_offset = - ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + ir_builder_->CreateMul(col, ir_builder_->getInt64(m())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For( - "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m(), /*step=*/1, [&](llvm::Value* scalar_row) { llvm::Value* product = vsl_.Mul( vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( is_first_scalar_col, ir_builder_->getInt1(is_first_tiled_column)); - ksl_.If( + ksl_.IfReturnVoid( setting_result_first_time, /*true_block_generator=*/ [&]() { @@ -365,51 +464,55 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( // // We have an inner epilogue loop to deal with the "B" sub-matrix and an outer // epilogue loop to deal with the C,D submatrix. -class RowMajorMatrixVectorProductEmitter { +class RowMajorMatrixVectorProductEmitter + : public GemvConfig::User { public: - RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, - int64 tile_cols, int64 m, int64 k, - llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* addend, llvm::Value* result, + class Config : public GemvConfig { + public: + explicit Config(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, + int64 m, int64 k, bool has_addend) + : GemvConfig(/*name=*/"row_major_gemv", scalar_type, + /*tile_rows=*/tile_rows, /*tile_cols=*/tile_cols, /*m=*/m, + /*k=*/k, /*has_addend=*/has_addend) {} + }; + + RowMajorMatrixVectorProductEmitter(const Config& config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) - : scalar_type_(scalar_type), - tile_rows_(tile_rows), - tile_cols_(tile_cols), - m_(m), - k_(k), + : config_(config), lhs_(lhs), rhs_(rhs), addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), - vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { - CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + vsl_(scalar_type(), /*vector_size=*/tile_cols(), ir_builder_, "") { + CHECK(tile_cols() > 0 && IsPowerOfTwo(static_cast(tile_cols()))); + CHECK(!has_addend() || addend != nullptr); } void Emit(); + const Config& config() const { return config_; } + private: - TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { - return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, - /*matrix_size_along_minor_dim=*/k_, + MemoryTile GetLhsMemoryTile(llvm::Value* row_start, int64 row_count) { + return MemoryTile(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k(), /*major_dim_offset=*/row_start, /*tile_size_along_major_dim=*/row_count); } void EmitOuterLoopBody(llvm::Value* row, int64 row_count); - void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + void EmitInnerLoopTiled(MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators); void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators); - PrimitiveType scalar_type_; - int64 tile_rows_; - int64 tile_cols_; - int64 m_; - int64 k_; + Config config_; llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* addend_; @@ -421,7 +524,7 @@ class RowMajorMatrixVectorProductEmitter { void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, int64 row_count) { - TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + MemoryTile lhs_memory_tile = GetLhsMemoryTile(/*row_start=*/row, /*row_count=*/row_count); std::vector vector_accumulators; std::vector scalar_accumulators; @@ -429,7 +532,7 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); } - EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + EmitInnerLoopTiled(&lhs_memory_tile, /*rows=*/row_count, &vector_accumulators); EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); @@ -466,12 +569,13 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, void RowMajorMatrixVectorProductEmitter::Emit() { // See the comment on the class declaration for the algorithm used here. - int64 row_remainder = m_ % tile_rows_; - int64 row_limit = m_ - row_remainder; + int64 row_remainder = m() % tile_rows(); + int64 row_limit = m() - row_remainder; - ksl_.For("dot.outer.tiled", - /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, - [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + ksl_.ForReturnVoid( + "dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows(), + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows()); }); if (row_remainder != 0) { EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); @@ -479,45 +583,46 @@ void RowMajorMatrixVectorProductEmitter::Emit() { } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( - TileLoader* lhs_tile_loader, int64 rows, + MemoryTile* lhs_memory_tile, int64 rows, std::vector* vector_accumulators) { - int64 column_limit = k_ - (k_ % tile_cols_); - - ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, - /*step=*/tile_cols_, [&](llvm::Value* col) { - std::vector lhs_tile = - lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); - llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); - for (int i = 0; i < rows; i++) { - llvm::Value* old_sum = (*vector_accumulators)[i].Get(); - (*vector_accumulators)[i].Set( - vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); - } - }); + int64 column_limit = k() - (k() % tile_cols()); + + ksl_.ForReturnVoid("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols(), [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_memory_tile->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set(vsl_.Add( + old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); } void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( llvm::Value* current_tile_row, int64 rows, std::vector* scalar_accumulators) { - int64 column_start = k_ - (k_ % tile_cols_); - if (column_start == k_) { + int64 column_start = k() - (k() % tile_cols()); + if (column_start == k()) { return; } for (int r = 0; r < rows; r++) { llvm::Value* total_offset = ir_builder_->CreateMul( ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), - ir_builder_->getInt64(k_)); + ir_builder_->getInt64(k())); llvm::Value* lhs_base_pointer = vsl_.ComputeOffsetPointer(lhs_, total_offset); - ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, - /*step=*/1, [&](llvm::Value* scalar_col) { - llvm::Value* product = - vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), - vsl_.LoadScalar(rhs_, scalar_col)); - llvm::Value* old_value = (*scalar_accumulators)[r].Get(); - (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); - }); + ksl_.ForReturnVoid( + "dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k(), + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); } } @@ -543,141 +648,234 @@ class MatrixMatrixBlockPanelEmitter { int64 k() const { return k_; } int64 n() const { return n_; } + string ToString() const { + return tensorflow::strings::StrCat(m(), "x", k(), "x", n()); + } + private: const int64 m_; const int64 k_; const int64 n_; }; - // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies - // `lhs` with `rhs` and stores the result in `result`. + // Represents the configuration of the GEBP emitter. The LLVM IR emitted by + // the emitter, modulo the LLVM values holding the input and output buffers, + // must be a function of the instance of `Config` passed to it. // - // `m`, `k` and `n` are the matrix multiplication dimensions. + // `dims` holds the matrix multiplication dimensions. // // `max_vectorization_width` is the maximum vector width (i.e. the width of // the largest vector register we will use). This can be larger than the // largest vector register supported by the machine -- LLVM will legalize // these large vector widths into legally sized vectors. + // + // `max_vector_count` is the maximum number of vectors of size + // `max_vectorization_width` that we will attempt to process at once. + // // `min_vectorization_width` is the smallest vector width the emitter will use // -- below that it will devolve to using a scalar loop. // - // `k_tiling_factor` is the number of elements along the reduction dimensions - // that we will attempt to process at once. - explicit MatrixMatrixBlockPanelEmitter( - llvm::Value* lhs, llvm::Value* rhs, llvm::Value* result, Dimensions dims, - int max_vectorization_width, int min_vectorization_width, - int k_tiling_factor, const TargetMachineFeatures& target_machine_features, - llvm::IRBuilder<>* ir_builder, PrimitiveType primitive_type) + // The innermost reduction loop executes the matrix multiply in tiles of size + // [`tile_size_m`, `tile_size_k`] from the LHS and [`tile_size_k`, + // ] in the RHS. + class Config { + public: + explicit Config(PrimitiveType scalar_type, Dimensions dims, + int64 max_vectorization_width, int64 max_vector_count, + int64 min_vectorization_width, int64 tile_size_m, + int64 tile_size_k) + : scalar_type_(scalar_type), + dims_(dims), + max_vectorization_width_(max_vectorization_width), + max_vector_count_(max_vector_count), + min_vectorization_width_(min_vectorization_width), + tile_size_m_(tile_size_m), + tile_size_k_(tile_size_k) {} + + string GetCacheKey() const { + return tensorflow::strings::StrCat( + "gebp_", PrimitiveType_Name(scalar_type()), "_", dims().ToString(), + "_", max_vectorization_width(), "_", min_vectorization_width(), "_", + tile_size_m(), "_", tile_size_k()); + } + + PrimitiveType scalar_type() const { return scalar_type_; } + Dimensions dims() const { return dims_; } + int64 max_vectorization_width() const { return max_vectorization_width_; } + int64 max_vector_count() const { return max_vector_count_; } + int64 min_vectorization_width() const { return min_vectorization_width_; } + + int64 tile_size_m() const { return tile_size_m_; } + int64 tile_size_k() const { return tile_size_k_; } + + private: + PrimitiveType scalar_type_; + Dimensions dims_; + int64 max_vectorization_width_; + int64 max_vector_count_; + int64 min_vectorization_width_; + int64 tile_size_m_; + int64 tile_size_k_; + }; + + // Creates an instance of MatrixMatrixBlockPanelEmitter that matrix-multiplies + // `lhs` with `rhs` and stores the result in `result`. + explicit MatrixMatrixBlockPanelEmitter(Config config, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) : lhs_(lhs), rhs_(rhs), result_(result), - dims_(dims), - max_vectorization_width_(max_vectorization_width), - min_vectorization_width_(min_vectorization_width), - k_tiling_factor_(k_tiling_factor), - target_machine_features_(target_machine_features), + config_(config), ir_builder_(ir_builder), - primitive_type_(primitive_type), ksl_(ir_builder_) { - CHECK(max_vectorization_width > 0 && - IsPowerOfTwo(static_cast(max_vectorization_width))); - CHECK(min_vectorization_width > 0 && - IsPowerOfTwo(static_cast(min_vectorization_width))); - CHECK_GT(k_tiling_factor, 0); + CHECK(max_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(max_vectorization_width()))); + CHECK_GT(max_vector_count(), 0); + CHECK(min_vectorization_width() > 0 && + IsPowerOfTwo(static_cast(min_vectorization_width()))); + CHECK_GE(max_vectorization_width(), min_vectorization_width()); + CHECK_GT(tile_size_k(), 0); } void Emit(); private: - // We can only iterate the `n` dimension for an extent that is divisible by - // the vectorization width. So we emit an outer loop that first processes the - // largest extent in `n` that is divisible by max_vectorization_width, then - // the largest remaining extent that is divisible by max_vectorization_width / - // 2 etc. This function emits that outermost loop. - void EmitChunkedLoopOverN(); - - // This emits a loop that loops over the `k` dimension in multiples of - // `k_tiling_factor` as much as possible and then emits a remainder epilogue. - void EmitLoopOverK(VectorSupportLibrary* vsl, llvm::Value* n_start, - llvm::Value* n_end); - - // This emits the inner reduction loop. This inner reduction loop processes - // all indices in the `m` dimension, [`k_start`, `k_end`) in the k dimension - // and [`n_start`, `n_end`) in the `n` dimension. - void EmitInnerLoop(int64 k_tiling_factor, llvm::Value* k_start, - llvm::Value* k_end, llvm::Value* n_start, - llvm::Value* n_end, VectorSupportLibrary* vsl); - - llvm::Value* getInt64(int64 value) { return ir_builder_->getInt64(value); } + // The HandleResiduesOnX helpers split the iteration space for dimension X + // into a multiple of the tile size on dimension X and an epilogue. These + // helpers ultimately call into `EmitTiledGemm` for emitting the + // tiled GEMM kernel. + + void HandleResiduesOnN(); + void HandleResiduesOnK(VectorSupportLibrary* vsl, llvm::Value* n_start, + llvm::Value* n_end); + void HandleResiduesOnM(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end); + + // This emits a tiled GEMM kernel. For a detailed description see the comment + // on the implementation. + void EmitTiledGemm(VectorSupportLibrary* vsl, int64 tile_size_k, + llvm::Value* k_start, llvm::Value* k_end, + llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, + llvm::Value* m_end); + + llvm::Value* GetInt64(int64 value) { return ir_builder_->getInt64(value); } + + Config config() const { return config_; } + Dimensions dims() const { return config().dims(); } + + int64 max_vectorization_width() const { + return config().max_vectorization_width(); + } + int64 max_vector_count() const { return config().max_vector_count(); } + int64 min_vectorization_width() const { + return config().min_vectorization_width(); + } + int64 tile_size_m() const { return config().tile_size_m(); } + int64 tile_size_k() const { return config().tile_size_k(); } + PrimitiveType scalar_type() const { return config().scalar_type(); } llvm::Value* lhs_; llvm::Value* rhs_; llvm::Value* result_; - Dimensions dims_; - - int64 max_vectorization_width_; - int64 min_vectorization_width_; - int64 k_tiling_factor_; + Config config_; - const TargetMachineFeatures& target_machine_features_; llvm::IRBuilder<>* ir_builder_; - PrimitiveType primitive_type_; KernelSupportLibrary ksl_; }; -void MatrixMatrixBlockPanelEmitter::Emit() { EmitChunkedLoopOverN(); } +void MatrixMatrixBlockPanelEmitter::Emit() { HandleResiduesOnN(); } + +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnN() { + // We can only iterate the `n` dimension for an extent that is divisible by + // the vectorization width. So we emit an outer loop that first processes the + // largest extent in `n` that is divisible by max_vectorization_width, then + // the largest remaining extent that is divisible by max_vectorization_width / + // 2 etc. + + int64 current_vectorization_width = + max_vector_count() * max_vectorization_width(); + int64 current_vector_count = max_vector_count(); -void MatrixMatrixBlockPanelEmitter::EmitChunkedLoopOverN() { - int64 current_vectorization_width = max_vectorization_width_; int64 n_start = 0; - while (n_start != dims_.n() && - current_vectorization_width >= min_vectorization_width_) { - int64 n_end = dims_.n() - (dims_.n() % current_vectorization_width); + while (n_start != dims().n() && + current_vectorization_width >= min_vectorization_width()) { + int64 n_end = dims().n() - (dims().n() % current_vectorization_width); if (n_start != n_end) { - VectorSupportLibrary vsl(primitive_type_, current_vectorization_width, + VectorSupportLibrary vsl(scalar_type(), current_vectorization_width, ir_builder_, "gebp"); - EmitLoopOverK(&vsl, getInt64(n_start), getInt64(n_end)); + HandleResiduesOnK(&vsl, GetInt64(n_start), GetInt64(n_end)); n_start = n_end; } - current_vectorization_width /= 2; + if (current_vector_count == 1) { + current_vectorization_width /= 2; + } else { + current_vector_count--; + current_vectorization_width = + current_vector_count * max_vectorization_width(); + } } - if (n_start != dims_.n()) { - VectorSupportLibrary vsl(primitive_type_, 1, ir_builder_, "gebp"); - ksl_.For("epi.n", n_start, dims_.n(), 1, [&](llvm::Value* n_i) { + if (n_start != dims().n()) { + VectorSupportLibrary vsl(scalar_type(), 1, ir_builder_, "gebp"); + ksl_.ForReturnVoid("epi.n", n_start, dims().n(), 1, [&](llvm::Value* n_i) { llvm::Value* n_i_next = ir_builder_->CreateAdd(n_i, ir_builder_->getInt64(1)); - EmitLoopOverK(&vsl, n_i, n_i_next); + HandleResiduesOnK(&vsl, n_i, n_i_next); }); } } -void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, - llvm::Value* n_start, - llvm::Value* n_end) { +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnK(VectorSupportLibrary* vsl, + llvm::Value* n_start, + llvm::Value* n_end) { int64 k_start = 0; - int64 k_end = dims_.k() - (dims_.k() % k_tiling_factor_); + int64 k_end = dims().k() - (dims().k() % tile_size_k()); if (k_end != k_start) { - EmitInnerLoop(k_tiling_factor_, getInt64(k_start), getInt64(k_end), n_start, - n_end, vsl); + HandleResiduesOnM(vsl, tile_size_k(), GetInt64(k_start), GetInt64(k_end), + n_start, n_end); k_start = k_end; } - if (k_start != dims_.k()) { - EmitInnerLoop(dims_.k() - k_start, getInt64(k_start), getInt64(dims_.k()), - n_start, n_end, vsl); + if (k_start != dims().k()) { + HandleResiduesOnM(vsl, dims().k() - k_start, GetInt64(k_start), + GetInt64(dims().k()), n_start, n_end); } } +void MatrixMatrixBlockPanelEmitter::HandleResiduesOnM( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end) { + const int64 m_end = dims().m() - dims().m() % tile_size_m(); + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, tile_size_m(), + GetInt64(0), GetInt64(m_end)); + + if (m_end != dims().m()) { + EmitTiledGemm(vsl, tile_size_k, k_start, k_end, n_start, n_end, + dims().m() - m_end, GetInt64(m_end), GetInt64(dims().m())); + } +} + +// The loop structure is: +// +// Iterate over dimension M as m: +// Iterate over dimension N as n: +// Iterate over dimension K as k: +// OutputTile[m,n] += Dot(LhsTile[m,k], RhsTile[k,n]) +// +// I.e. a just a tiled version of a "naive" GEMM. +// // The tiling scheme is as follows: // // Let the LHS be: // -// +---+---+---+ -// | a | b | c | . -// +---+---+---+ . -// | | | | . -// +---+---+---+ +// +----+----+----+ +// | a0 | b0 | c0 | . +// +----+----+----+ . +// | a1 | b1 | c1 | . +// +----+----+----+ // .. .. // // and the RHS be: @@ -691,75 +889,87 @@ void MatrixMatrixBlockPanelEmitter::EmitLoopOverK(VectorSupportLibrary* vsl, // +----+----+----+----+ . // ...... ...... // -// and let k_tiling_factor be 3 and the vector width (implicitly denoted by -// `vsl`) be 4. +// and let tile_size_m=2, tile_size_k=3 and the vector width (implicitly denoted +// by `vsl`) be 4. Then we want to matrix multiply this tile to get a [2,4] +// matrix that we can increment the result matrix by. // -// Then we +// First broadcast the rows row in LHS to 3 vectors of width 4, giving us a rank +// 3 array, L, of dimension [2,3,4]: // -// 1. broadcast the first row in LHS to 3 vectors of width 4 -// 2. elementwise multiply the RHS rows with these broadcasted vectors -// 3. elementwise add them: +// L[0,_,_] * L[1,_,_] +// * +// +----+----+----+----+ * +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | a1 | a1 | a1 | a1 | +// +----+----+----+----+ * +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | b1 | b1 | b1 | b1 | +// +----+----+----+----+ * +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | c1 | c1 | c1 | c1 | +// +----+----+----+----+ * +----+----+----+----+ // -// +---+---+---+---+ +----+----+----+----+ -// | a | a | a | a | * | p0 | p1 | p2 | p3 | + -// +---+---+---+---+ +----+----+----+----+ // -// +---+---+---+---+ +----+----+----+----+ -// | b | b | b | b | * | q0 | q1 | q2 | q3 | + -// +---+---+---+---+ +----+----+----+----+ +// Then we FMA L[0,_,_] with the RHS to get the first row of the result and +// L[1,_,_] with the RHS to get the second row of the result. For example, +// L[0,_,_] is computed as: // -// +---+---+---+---+ +----+----+----+----+ -// | c | c | c | c | * | r0 | r1 | r2 | r3 | -// +---+---+---+---+ +----+----+----+----+ +// +----+----+----+----+ +----+----+----+----+ +// | a0 | a0 | a0 | a0 | * | p0 | p1 | p2 | p3 | + +// +----+----+----+----+ +----+----+----+----+ // -// to get: +// +----+----+----+----+ +----+----+----+----+ +// | b0 | b0 | b0 | b0 | * | q0 | q1 | q2 | q3 | + +// +----+----+----+----+ +----+----+----+----+ // -// +----------------+----------------+----------------+----------------+ -// | a*p0+b*q0+c*r0 | a*p1+b*q1+c*r1 | a*p2+b*q2+c*r2 | a*p3+b*q3+c*r3 | -// +----------------+----------------+----------------+----------------+ +// +----+----+----+----+ +----+----+----+----+ +// | c0 | c0 | c0 | c0 | * | r0 | r1 | r2 | r3 | +// +----+----+----+----+ +----+----+----+----+ // -// which we increment into the appropriate region in the result. -void MatrixMatrixBlockPanelEmitter::EmitInnerLoop( - int64 k_tiling_factor, llvm::Value* k_start, llvm::Value* k_end, - llvm::Value* n_start, llvm::Value* n_end, VectorSupportLibrary* vsl) { - ksl_.For("dot.m", 0, dims_.m(), 1, [&](llvm::Value* m_i) { - // This outer loop iterates over all of the M dimension - llvm::Value* result_row_begin = vsl->ComputeOffsetPointer( - result_, /*offset_elements=*/m_i, /*scale=*/dims_.n()); - llvm::Value* lhs_row_begin = vsl->ComputeOffsetPointer( - lhs_, /*offset_elements=*/m_i, /*scale=*/dims_.k()); - - ksl_.For("dot.k", k_start, k_end, k_tiling_factor, [&](llvm::Value* k_i) { - // broadcasted_a is the broadcasted set of vectors denoted as , - // etc. in the diagram. - std::vector broadcasted_a; - broadcasted_a.reserve(k_tiling_factor); - for (int i = 0; i < k_tiling_factor; i++) { - broadcasted_a.push_back(vsl->LoadBroadcast( - lhs_row_begin, ir_builder_->CreateAdd(getInt64(i), k_i))); - } +// to get: +// +// +-------------------+-------------------+-------------------+--------- +// | a0*p0+b0*q0+c0*r0 | a0*p1+b0*q1+c0*r1 | a0*p2+b0*q2+c0*r2 | ... +// +-------------------+-------------------+-------------------+--------- +void MatrixMatrixBlockPanelEmitter::EmitTiledGemm( + VectorSupportLibrary* vsl, int64 tile_size_k, llvm::Value* k_start, + llvm::Value* k_end, llvm::Value* n_start, llvm::Value* n_end, + int64 tile_size_m, llvm::Value* m_start, llvm::Value* m_end) { + ksl_.ForReturnVoid( + "dot.m", m_start, m_end, tile_size_m, [&](llvm::Value* m_i) { + MemoryTile result_memory_tile( + vsl, ir_builder_, /*matrix=*/result_, + /*matrix_size_along_minor_dim=*/dims().n(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + MemoryTile lhs_memory_tile(vsl, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/dims().k(), + /*major_dim_offset=*/m_i, + /*tile_size_along_major_dim=*/tile_size_m); + ksl_.ForReturnVoid( + "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { + TileVariable result_tile_var(vsl, + result_memory_tile.LoadTile(n_i)); + ksl_.ForReturnVoid( + "dot.k", k_start, k_end, tile_size_k, [&](llvm::Value* k_i) { + MemoryTile rhs_memory_tile(vsl, ir_builder_, rhs_, + dims().n(), k_i, tile_size_k); + std::vector> lhs_tile = + lhs_memory_tile.LoadBroadcastTile(k_i, tile_size_k); + std::vector rhs_tile = + rhs_memory_tile.LoadTile(n_i); + std::vector result_tile = + result_tile_var.Get(); + for (int64 r_m_i = 0; r_m_i < tile_size_m; r_m_i++) { + for (int64 r_k_i = 0; r_k_i < tile_size_k; r_k_i++) { + result_tile[r_m_i] = + vsl->MulAdd(lhs_tile[r_m_i][r_k_i], rhs_tile[r_k_i], + result_tile[r_m_i]); + } + } + result_tile_var.Set(result_tile); + }); - // rhs_loader will be used to load the tile off of the RHS, denoted as - // <, ...> in the diagram. - TileLoader rhs_loader(vsl, ir_builder_, rhs_, dims_.n(), k_i, - k_tiling_factor); - ksl_.For( - "dot.n", n_start, n_end, vsl->vector_size(), [&](llvm::Value* n_i) { - // This loop iterates over the N dimension. It loads the tile from - // RHS, does the FMA resulting in the - // in the diagram and increments - // the result. - std::vector tile = rhs_loader.LoadTile(n_i); - llvm::Value* result_accumulator = - vsl->LoadVector(result_row_begin, n_i); - for (int i = 0; i < tile.size(); i++) { - result_accumulator = - vsl->MulAdd(tile[i], broadcasted_a[i], result_accumulator); - } - vsl->StoreVector(result_accumulator, result_row_begin, n_i); - }); - }); - }); + result_memory_tile.StoreTile(result_tile_var.Get(), n_i); + }); + }); } } // namespace @@ -827,8 +1037,6 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( return false; } - VLOG(2) << "Emitting GEBP kernel in LLVM IR"; - llvm::Value* lhs = lhs_array_.GetBasePointer(); llvm::Value* rhs = rhs_array_.GetBasePointer(); llvm::Value* target = target_array_.GetBasePointer(); @@ -846,14 +1054,41 @@ bool DotOpEmitter::EmitExperimentalGebpDotIfEnabled( target, ir_builder_->getInt8(0), size_bytes, target_machine_features_.minimum_alignment_for_allocation(size_bytes)); - MatrixMatrixBlockPanelEmitter::Dimensions gebp_dims(/*m=*/m, /*k=*/k, - /*n=*/n); - MatrixMatrixBlockPanelEmitter gebp_emitter( - /*lhs=*/lhs, /*rhs=*/rhs, /*result=*/target, gebp_dims, - /*max_vectorization_width=*/8, /*min_vectorization_width=*/4, - /*k_tiling_factor=*/8, target_machine_features_, ir_builder_, - primitive_type); - gebp_emitter.Emit(); + int64 max_target_vector_width = + target_machine_features_.vector_register_num_elements( + *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + + int64 tile_size_m, tile_size_k, tile_size_n_in_vector_width; + std::tie(tile_size_m, tile_size_k, tile_size_n_in_vector_width) = + GetGemmTileSize(); + + MatrixMatrixBlockPanelEmitter::Config config( + /*scalar_type=*/primitive_type, + MatrixMatrixBlockPanelEmitter::Dimensions{/*m=*/m, /*k=*/k, /*n=*/n}, + /*max_vectorization_width=*/max_target_vector_width, + /*max_vector_count=*/tile_size_n_in_vector_width, + /*min_vectorization_width=*/std::min(4, max_target_vector_width), + /*tile_size_m=*/tile_size_m, /*tile_size_k=*/tile_size_k); + + VLOG(2) << "Emitting GEBP kernel in LLVM IR with config " + << config.GetCacheKey(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs, rhs, target, + [this, config](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* target) { + MatrixMatrixBlockPanelEmitter gebp_emitter( + config, /*lhs=*/lhs, /*rhs=*/rhs, + /*result=*/target, ir_builder_); + gebp_emitter.Emit(); + }); + return true; } @@ -942,47 +1177,39 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = vector_register_element_size; - int64 tile_cols = tiling_factor; - - string kernel_name = tensorflow::strings::StrCat( - "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + ColumnMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/vector_register_element_size, /*tile_cols=*/tiling_factor, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { ColumnMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - int64 tile_rows = tiling_factor; - int64 tile_cols = vector_register_element_size; - - string kernel_name = tensorflow::strings::StrCat( - "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, - "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + RowMajorMatrixVectorProductEmitter::Config config( + /*scalar_type=*/primitive_type, + /*tile_rows=*/tiling_factor, /*tile_cols=*/vector_register_element_size, + /*m=*/m, /*k=*/k, /*has_addend=*/addend_array_ != nullptr); KernelSupportLibrary::EmitAndCallOutlinedKernel( /*enable_fast_math=*/enable_fast_math, - /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, - lhs_op, rhs_op, + /*optimize_for_size=*/optimize_for_size, ir_builder_, + config.GetCacheKey(), lhs_op, rhs_op, addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, - [this, tile_rows, tile_cols, m, k, primitive_type]( - llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, - llvm::Value* result_op) { + [this, config](llvm::Value* lhs_op, llvm::Value* rhs_op, + llvm::Value* addend_op, llvm::Value* result_op) { RowMajorMatrixVectorProductEmitter emitter( - primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, - addend_op, result_op, ir_builder_); + config, lhs_op, rhs_op, addend_op, result_op, ir_builder_); emitter.Emit(); }); } @@ -1074,8 +1301,11 @@ Status DotOpEmitter::Emit() { // from messing up the vectorization. std::unique_ptr reduction_loop = loop_nest.AddLoop( 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction", - /*prevent_unrolling=*/lhs_reduction_along_minor_dimension && - rhs_reduction_along_minor_dimension); + /*unroll_mode=*/ + (lhs_reduction_along_minor_dimension && + rhs_reduction_along_minor_dimension) + ? xla::llvm_ir::UnrollMode::kNoUnroll + : xla::llvm_ir::UnrollMode::kDefaultUnroll); // The final entry in the rhs and lhs indexes is the indvar of the // reduction loop. diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index d88ccea0dbc845c0d9a580a5b118c57c888fb557..ed2a18976a0f1a88e7bb4632d3a63167d5c146ad 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -143,6 +143,17 @@ class DotOpEmitter { .value_or(kDefaultTilingFactor); } + std::tuple GetGemmTileSize() const { + // Tuned for broadwell - Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz + // + // TODO(b/80093688): Tune for other architectures and centralize this + // information in one place. + const std::tuple kDefaultTileSize = + std::tuple(11, 9, 1); + return options::LlvmIrGemmTileSize(hlo_module_config_) + .value_or(kDefaultTileSize); + } + // Returns true if we should use an experimental implementation of GEMM // (general matrix matrix multiplication) if possible. bool EnableExperimentalLlvmIrGemm() const { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc index abb2471e6ae6b2f2949ab2e91235e5047ae404f8..530ebce854fedf4e4db12139d5b56087b1176a6c 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -35,7 +35,7 @@ ENTRY Conv { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* entry_computation = module->entry_computation(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 13bd5e73db500e20b0e8c33bf921ee2457e126e5..59223fddac2f5f7e2e85de4d37e4b6c5760ae697 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -160,39 +160,44 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { - llvm::GlobalVariable* result; +llvm::Constant* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { + llvm::Constant* result; // We avoid creating large constants in the LLVM IR since LLVM is not // efficient for large constant arrays. We still emit "small enough" constant // arrays into the Ir, in the off chance the LLVM optimizer can do something // interesting with it. + // + // TODO(b/29904935): Remove the large constant pool. const int kMaxInternalConstantSizeInBytes = 128; if (external_constant_pool_ && ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { string global_name = tensorflow::strings::StrCat( "constant_global_", external_global_constant_counter_++); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/IrShapeType(literal.shape()), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::ExternalLinkage, /*Initializer=*/nullptr, /*Name=*/AsStringRef(global_name)); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); external_constant_pool_->Insert(global_name, literal, MinimumAlignmentForShape(literal.shape())); + result = result_global; } else { llvm::Constant* initializer = llvm_ir::ConvertLiteralToIrConstant(literal, module_); - result = new llvm::GlobalVariable( + llvm::GlobalVariable* result_global = new llvm::GlobalVariable( /*Module=*/*module_, /*Type=*/initializer->getType(), /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::PrivateLinkage, /*Initializer=*/initializer, /*Name=*/""); - result->setAlignment(MinimumAlignmentForShape(literal.shape())); + result_global->setAlignment(MinimumAlignmentForShape(literal.shape())); + result = llvm::ConstantExpr::getBitCast( + result_global, IrShapeType(literal.shape())->getPointerTo()); } return result; } @@ -200,7 +205,7 @@ llvm::GlobalVariable* IrEmitter::EmitGlobalForLiteral(const Literal& literal) { Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); const Literal& literal = constant->literal(); - llvm::GlobalVariable* global_for_const; + llvm::Constant* global_for_const; auto it = emitted_literals_.find(&literal); if (it != emitted_literals_.end()) { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index f49cfc1dc378bb80da3ddf995363acfa2081067b..32c536e18fee86cc60067ba3b25ab1eb0e4233df 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -527,7 +527,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - llvm::GlobalVariable* EmitGlobalForLiteral(const Literal& literal); + // Returns a ConstExpr bitcast. + llvm::Constant* EmitGlobalForLiteral(const Literal& literal); const HloModuleConfig& hlo_module_config_; @@ -548,7 +549,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { } }; - tensorflow::gtl::FlatMap emitted_literals_; diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 63d0f7b95c7e45913c707471dbe2dc62e05251d6..4fa5984b0466b178a587e97cbced97deac749f74 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -38,7 +38,7 @@ class SimpleCostModel : public ParallelCostModel { const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: @@ -63,7 +63,7 @@ class DefaultCostModel : public ParallelCostModel { int64 max_parallelism; // Calculate flops-to-bytes-ratio for 'instruction'. const int64 bytes_accessed = - std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + std::max(int64{1}, cost_analysis_->bytes_accessed(*instruction)); const float flops_to_bytes_ratio = cost_analysis_->flop_count(*instruction) / static_cast(bytes_accessed); @@ -93,7 +93,7 @@ class DefaultCostModel : public ParallelCostModel { } // Return target parallel task count in [1, max_parallelism_]. return std::min(max_parallelism, - std::max(1LL, instruction_cost / min_cost_per_thread)); + std::max(int64{1}, instruction_cost / min_cost_per_thread)); } private: diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc index 92da5f71c23d5e1450b39ea8b7bb8345f6fabb3b..f8c8dd5e93d53db8d87be0208b5cf4daac3464f1 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML) #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" #include "third_party/intel_mkl_ml/include/mkl_cblas.h" #include "third_party/intel_mkl_ml/include/mkl_service.h" diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc index 42fe955f1917e0268dc739e44fbd0a7afb39185c..d12c5396148d32adb178b955a34e050cc56784da 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.cc +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -115,7 +115,7 @@ ShapePartitionIterator::ShapePartitionIterator( for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { const int64 dim_size = shape_.dimensions(dimensions_[i]); dimension_partition_sizes_[i] = - std::max(1LL, dim_size / dimension_partition_counts_[i]); + std::max(int64{1}, dim_size / dimension_partition_counts_[i]); } // Calculate the partition strides for each dimension. diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 67f776e7b5883f425b41c05342b74bebe223e17f..66ae5ef0f66e90982102d73e474f5d0582f5415c 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -152,9 +152,9 @@ tf_cc_test( srcs = ["cpu_literal_caching_test.cc"], deps = [ "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -166,9 +166,9 @@ tf_cc_test( srcs = ["cpu_outfeed_test.cc"], deps = [ "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc index ed8f375bd6186e4805fe9ded5be9ae7c9f4d5c84..faac927027c48e44eb8ff1fcc4109fbc177fc579 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_external_constants_test.cc @@ -64,8 +64,8 @@ TEST_F(CpuExternalConstantsTest, BasicNegative) { // The constant array in this test case is small enough that there is no need // to externalize it. TestWithArray(/*rows=*/4, /*cols=*/4, R"( -CHECK-NOT: @constant_global_0 = external constant [4 x [4 x float]], align 8 -CHECK: @0 = private constant [4 x [4 x float]] {{.*}}, align 8 +CHECK-NOT: @constant_global_0 = external constant [16 x float], align 8 +CHECK: @0 = private constant [16 x float] {{.*}}, align 8 )"); } } // namespace diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index d6e0425c5542be89835571f0103b1829f63cc2c2..27044b1d62027e3b83744c486cb790269e505aff 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -55,12 +55,12 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] -CHECK-NOT: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [12 x float] +CHECK-NOT: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", @@ -78,34 +78,34 @@ TEST_F(CpuDuplicateConstantsTest, RepeatedTupleConstants) { HloModule RepeatedConstants while_body { - arg_body = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) - ROOT const = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) + arg_body = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) + ROOT const = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) } while_cond { - arg_cond = (f32[2,1]{1,0}, f32[2]{0}) parameter(0) + arg_cond = (f32[2,1]{1,0}, f32[1]{0}) parameter(0) ROOT unknown = pred[] infeed() } ENTRY main { param = f32[2,3,2] parameter(0) - const_a = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) - const_b = (f32[2,1]{1,0}, f32[2]{0}) while((f32[2,1]{1,0}, f32[2]{0}) const_a), condition=while_cond, body=while_body + const_a = (f32[2,1]{1,0}, f32[1]{0}) constant((f32[2,1], f32[1]) ( f32[2,1] { { 1 }, { 2 } }, {2} )) + const_b = (f32[2,1]{1,0}, f32[1]{0}) while((f32[2,1]{1,0}, f32[1]{0}) const_a), condition=while_cond, body=while_body - out0 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_a) - ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[2]{0}) const_b) + out0 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_a) + ROOT out1 = () outfeed((f32[2,1]{1,0}, f32[1]{0}) const_b) } )"; string filecheck_pattern = R"( +CHECK: private constant [1 x float] CHECK: private constant [2 x float] -CHECK: private constant [2 x [1 x float]] +CHECK-NOT: private constant [1 x float] CHECK-NOT: private constant [2 x float] -CHECK-NOT: private constant [2 x [1 x float]] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index 879372eb13884cdb7edd8cfb3e8b4bac4e314951..1ee279290b6fcfe775ce9867d424b1c031f5d2bd 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" namespace xla { namespace cpu { @@ -37,11 +37,11 @@ ENTRY main { )"; string filecheck_pattern = R"( -CHECK: private constant [2 x [3 x [2 x float]]] +CHECK: private constant [12 x float] )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); CpuAotCompilationOptions options{ /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index cd1165e23812861ba9951546b7dd744529232196..c444d151858d3a152a01b99657ffae89ebc6b487 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -427,5 +427,27 @@ llvm::Value* LlvmVariable::Get() const { void LlvmVariable::Set(llvm::Value* new_value) { ir_builder_->CreateStore(new_value, alloca_); } + +TileVariable::TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value) { + for (llvm::Value* initial_vector_value : initial_value) { + storage_.emplace_back(vector_support, initial_vector_value); + } +} + +std::vector TileVariable::Get() const { + std::vector result; + c_transform(storage_, std::back_inserter(result), + [&](VectorVariable vect_var) { return vect_var.Get(); }); + return result; +} + +void TileVariable::Set(tensorflow::gtl::ArraySlice value) { + CHECK_EQ(value.size(), storage_.size()); + for (int64 i = 0, e = value.size(); i < e; i++) { + storage_[i].Set(value[i]); + } +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index edcaec584997b17dce30b8c46fda4abc78441064..49c2a4e2f4bae9e1672b7d2fe891301bce08bd4b 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace cpu { @@ -317,6 +318,21 @@ class ScalarVariable : public LlvmVariable { Set(initial_value); } }; + +// This wraps a set of alloca-backed stack variables that can, as a whole, store +// a tile. A "tile" is a sequence of vectors that is typically used as a 2D +// grid of scalar values (e.g. for tiled GEMMs). +class TileVariable { + public: + TileVariable(VectorSupportLibrary* vector_support, + std::vector initial_value); + + std::vector Get() const; + void Set(tensorflow::gtl::ArraySlice value); + + private: + std::vector storage_; +}; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index b9d7ec9c2e17e560580fcea060bf552c42fe3b3c..ee2b455730f8f520db6652f0352f8a96291cac73 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -197,6 +197,10 @@ class DfsHloVisitorBase { return HandleElementwiseUnary(hlo); } + virtual Status HandleDomain(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; @@ -239,6 +243,8 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstructionPtr root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 240faebe62f5cee4f61b3c36b5e8f653cfd6db8e..6934e00a4b665e9e6a4302e0c0a8ce1d5bb94373 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleGenerateToken(HloInstructionPtr token) override { + return DefaultAction(token); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 9a8bab353ef6b1e0b05b250d35296bc3cef8bc37..93fea7ead7a86bb34c449668fd88a58145681eb1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -456,17 +456,15 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::ConstantFP::get(type, 1.0))); } case HloOpcode::kIsFinite: { - // (x == x) && abs(x) != inf + // abs(x) o!= inf, this works because the comparison returns false if + // either operand is NaN. auto type = operand_value->getType(); - auto equal_self = - ir_builder_->CreateFCmpOEQ(operand_value, operand_value); auto abs_value = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_); auto infinity = llvm::ConstantFP::getInfinity(type); auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); - auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc index b43dc0c65d9b6e7c05e06010ba2ff2eb27392295..8980d4303353a132ada2b3c685b4f2856c33c6a1 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -33,7 +33,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), args, nullopt)); } }; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 8119478ce934da06969024905e5e054e0b509b03..6df172db8e541c5cef7aab04f3d8611fc735e8b0 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -129,20 +129,6 @@ StatusOr Executable::ExecuteOnStreamWrapper( return return_value; } -Status Executable::DumpSessionModule() { - TF_RET_CHECK(dumping()); - const string& directory_path = - module_config().debug_options().xla_dump_executions_to(); - VersionedComputationHandle versioned_handle = entry_computation_handle(); - // This filename does not include the version number because the computation - // is only ever executed at one version. - string filename = tensorflow::strings::Printf( - "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(), - session_module_->entry().name().c_str(), ++execution_count_); - return Executable::DumpToDirectory(directory_path, filename, - *session_module_); -} - Status Executable::DumpHloSnapshot() { TF_RET_CHECK(dumping_snapshot()); TF_RET_CHECK(hlo_snapshot_->has_hlo() && @@ -156,26 +142,6 @@ Status Executable::DumpHloSnapshot() { return Executable::DumpToDirectory(directory_path, filename, *hlo_snapshot_); } -/* static */ Status Executable::DumpToDirectory( - const string& directory_path, string filename, - const SessionModule& session_module) { - tensorflow::Env* env = tensorflow::Env::Default(); - if (!env->IsDirectory(directory_path).ok()) { - // NB! CreateDir does not work reliably with multiple XLA threads -- two - // threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); - } - filename = SanitizeFileName(std::move(filename)); - string file_path = tensorflow::io::JoinPath(directory_path, filename); - string result; - TF_RET_CHECK( - tensorflow::SerializeToStringDeterministic(session_module, &result)); - return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, - result); -} - /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, const HloSnapshot& hlo_session) { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 4f0466c544738fa1ec4602ee5104daee8d969c83..dc1f26ea65cc707d4f0522af2aa3ec40621632f1 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -27,9 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/shaped_buffer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -132,26 +130,12 @@ class Executable { const HloModuleConfig& module_config() const { return hlo_module_->config(); } - // Returns the versioned computation handle of the computation computed by - // this executable. - const VersionedComputationHandle& entry_computation_handle() const { - return hlo_module_->entry_computation_handle(); - } - // The shape (including layout) that results from this execution. This is the // shape of the DeviceMemoryBase result value in ExecuteOnStream above. const Shape& host_result_shape() const { return hlo_module_->config().host_entry_computation_layout().result_shape(); } - // TODO(b/74197823): Delete the session module dumping helpers. - void set_session_module(std::unique_ptr session_module) { - session_module_ = std::move(session_module); - } - bool dumping() const { return session_module_ != nullptr; } - SessionModule* session_module() const { return session_module_.get(); } - Status DumpSessionModule(); - // Dumping helpers. void set_hlo_snapshot(std::unique_ptr hlo_snapshot) { hlo_snapshot_ = std::move(hlo_snapshot); @@ -160,10 +144,6 @@ class Executable { HloSnapshot* hlo_snapshot() const { return hlo_snapshot_.get(); } Status DumpHloSnapshot(); - // Dump session_module to directory_path/filename. - static Status DumpToDirectory(const string& directory_path, string filename, - const SessionModule& session_module); - // Dump hlo snapshot to directory_path/filename. static Status DumpToDirectory(const string& directory_path, string filename, const HloSnapshot& hlo_session); @@ -179,9 +159,6 @@ class Executable { // around. const std::unique_ptr hlo_module_; - // SessionModule this was compiled from. Null if not dumping executions. - std::unique_ptr session_module_; - // HloSnapshot this was compiled from. Null if not dumping executions. std::unique_ptr hlo_snapshot_; diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/service/g3doc/hlo_parser.md similarity index 100% rename from tensorflow/compiler/xla/tools/parser/README.md rename to tensorflow/compiler/xla/service/g3doc/hlo_parser.md diff --git a/tensorflow/compiler/xla/service/gather_expander_test.cc b/tensorflow/compiler/xla/service/gather_expander_test.cc index 1c72ca066502eb549bf8638cdf0b7827b06f92d7..020ffcd106862cb2641a9f3bceb70acdd969a458 100644 --- a/tensorflow/compiler/xla/service/gather_expander_test.cc +++ b/tensorflow/compiler/xla/service/gather_expander_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gather_expander.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -36,7 +36,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); Status status = GatherExpander{}.Run(module.get()).status(); EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); @@ -63,7 +63,7 @@ ENTRY main { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text)); + ParseHloString(hlo_text)); TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get())); ASSERT_TRUE(changed); diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 4012f87f2bf69d1ab056da5d6c750441c7404980..5e02631a5856a45762027d246144d6f6dd913c26 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1,6 +1,8 @@ # Description: # GPU-specific components in XLA service implementation. +load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") + licenses(["notice"]) # Apache 2.0 package(default_visibility = [":friends"]) @@ -23,6 +25,11 @@ filegroup( load("//tensorflow:tensorflow.bzl", "tf_cc_test") +xla_proto_library( + name = "backend_configs", + srcs = ["backend_configs.proto"], +) + cc_library( name = "gpu_constants", srcs = ["gpu_constants.cc"], @@ -133,6 +140,7 @@ cc_library( "ir_emitter_unnested.h", ], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":elemental_ir_emitter", ":gpu_constants", @@ -156,6 +164,7 @@ cc_library( "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", @@ -266,6 +275,7 @@ cc_library( "while_thunk.h", ], deps = [ + ":backend_configs", ":buffer_allocations", ":cudnn_convolution_runner", ":infeed_manager", @@ -322,6 +332,7 @@ cc_library( srcs = ["cudnn_convolution_algorithm_picker.cc"], hdrs = ["cudnn_convolution_algorithm_picker.h"], deps = [ + ":backend_configs", ":cudnn_convolution_runner", ":gpu_executable", ":ir_emission_utils", @@ -338,6 +349,7 @@ cc_library( srcs = ["cudnn_convolution_runner.cc"], hdrs = ["cudnn_convolution_runner.h"], deps = [ + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", @@ -401,10 +413,41 @@ tf_cc_test( srcs = ["instruction_fusion_test.cc"], deps = [ ":instruction_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:multi_output_fusion", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + deps = [ + ":multi_output_fusion", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", ], ) @@ -446,9 +489,9 @@ tf_cc_test( ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -508,6 +551,7 @@ cc_library( ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", + ":multi_output_fusion", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -542,6 +586,8 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_constant_sinking", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", @@ -587,14 +633,18 @@ cc_library( srcs = ["gpu_layout_assignment.cc"], hdrs = ["gpu_layout_assignment.h"], deps = [ + ":gpu_options", ":ir_emission_utils", + ":stream_executor_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -691,6 +741,27 @@ cc_library( ], ) +cc_library( + name = "gpu_options", + srcs = ["gpu_options.cc"], + hdrs = ["gpu_options.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "stream_executor_util", + srcs = ["stream_executor_util.cc"], + hdrs = ["stream_executor_util.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + tf_cc_test( name = "gpu_hlo_support_checker_test", srcs = ["gpu_hlo_support_checker_test.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/backend_configs.proto b/tensorflow/compiler/xla/service/gpu/backend_configs.proto new file mode 100644 index 0000000000000000000000000000000000000000..640c6392b8b820c708b853c2a3cea4d4116e85a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/backend_configs.proto @@ -0,0 +1,27 @@ +syntax = "proto3"; + +package xla.gpu; + +// Backend configs for XLA:GPU. +// +// These are metadata that the GPU backend attaches to HloInstrucitons and later +// uses during e.g. codegen. +// +// Remember that proto3 doesn't give clients a way to tell the difference +// between a field not being present and a field having the default value. +// Choose your defaults carefully. +// +// No guarantee is made about the stability of these protos. +// +// See HloInstruction::backend_config() for more info. + +// Backend config for a convolution that runs through cudnn. +message CudnnConvBackendConfig { + // Opaque algorithm number of cudnn algorithm chosen for this conv. + int64 algorithm = 1; + + // Whether we may use tensor cores when running this conv. Even if this is + // true, cudnn may choose not to use tensor cores, e.g. because the GPU or + // selected algorithm doesn't support it. + bool tensor_ops_enabled = 2; +} diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 6a46bdb9b438f81dc564b9033f5d302f90b6a997..3dc98c4c93ea2b9b68dd3ee27794a39847f8756c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -316,21 +317,20 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( Shape new_call_shape = ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), ShapeUtil::MakeShape(U8, {scratch_bytes})}); - HloInstruction* algorithm_hlo = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); - HloInstruction* tensor_ops_enabled_hlo = - computation->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(tensor_ops_enabled))); + + CudnnConvBackendConfig backend_config; + backend_config.set_algorithm(algorithm); + backend_config.set_tensor_ops_enabled(tensor_ops_enabled); HloInstruction* new_call = computation->AddInstruction(HloInstruction::CreateCustomCall( new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, - tensor_ops_enabled_hlo}, + {instr->mutable_operand(0), instr->mutable_operand(1)}, instr->custom_call_target())); new_call->set_window(instr->window()); new_call->set_convolution_dimension_numbers( instr->convolution_dimension_numbers()); + TF_RETURN_IF_ERROR(new_call->set_backend_config(backend_config)); // Repackage new_call so it has the same shape as the original call, namely // (conv_result, u8[0]). diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index 10b4c3de89989c52cfea5273c3d5b0beef76abd2..0645fbb3ad39f1f1649caf45a6068b5a196c30b9 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -113,8 +115,17 @@ Status RunCudnnConvolution( // cuDNN's convolution APIs support the BDYX layout for activations/output and // the OIYX layout for weights. + DataLayout input_dl; + FilterLayout filter_dl; + DataLayout output_dl; + + TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl), + XlaConvLayoutsToStreamExecutorLayouts( + dnums, input_shape.layout(), filter_shape.layout(), + output_shape.layout())); + BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) + input_descriptor.set_layout(input_dl) .set_feature_map_count( input_shape.dimensions(dnums.input_feature_dimension())) .set_count(input_shape.dimensions(dnums.input_batch_dimension())); @@ -126,7 +137,7 @@ Status RunCudnnConvolution( } FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + filter_descriptor.set_layout(filter_dl) .set_input_feature_map_count( filter_shape.dimensions(dnums.kernel_input_feature_dimension())) .set_output_feature_map_count( @@ -149,7 +160,7 @@ Status RunCudnnConvolution( } BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) + output_descriptor.set_layout(output_dl) .set_feature_map_count( output_shape.dimensions(dnums.output_feature_dimension())) .set_count(output_shape.dimensions(dnums.output_batch_dimension())); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index e5e2a0478a0659986ddec8d6785827b14b9efb56..b812dd7d3fbb25f279e87f79b647e299f29073ea 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -53,11 +53,17 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrAppend; +namespace { // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { - return operand->opcode() == HloOpcode::kConstant && - operand->literal().IsAllFloat(value); + if (operand->opcode() == HloOpcode::kConstant && + operand->literal().IsAllFloat(value)) { + return true; + } + return operand->opcode() == HloOpcode::kBroadcast && + IsFPLiteralWithValue(operand->operand(0), value); } +} // namespace GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 2217776c7d5a5f92c520d56222988f80401be9e4..b22bb1d39ba177ef42673c7a3755694b43c15d14 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace gpu { @@ -40,7 +40,7 @@ class FusionMergerTest : public HloTestBase {}; // Tuple // TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule MergeSharedFusionInstruction comp.3 { @@ -104,7 +104,7 @@ ENTRY MergeSharedFusionInstruction.Computation0 { // // Fusion2 is not merged because it exceeds the threshold flops-to-bytes ratio. TEST_F(FusionMergerTest, FlopsToBytesRatioThresholdExceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule FlopsToBytesRatioThresholdExceeded comp.2 { @@ -162,7 +162,7 @@ ENTRY FlopsToBytesRatioThresholdExceeded.Computation1 { // is merged into Fusion0 and Fusion1) would exceed the bytes transferred // threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdExeceeded comp.2 { @@ -210,7 +210,7 @@ ENTRY BytesTransferredThresholdExeceeded.Computation2 { // Fusion2 is reduced for this test which makes the merge operation into its // operand below the bytes transferred threshold. TEST_F(FusionMergerTest, BytesTransferredThresholdNotExeceeded) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule BytesTransferredThresholdNotExeceeded comp.2 { @@ -253,7 +253,7 @@ ENTRY BytesTransferredThresholdNotExeceeded.Computation2 { // Check that we're willing to merge f1_computation into f2_computation, even // though f2 is an input fusion node. TEST_F(FusionMergerTest, WillMergeIntoInputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m f1_computation { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index d50153d8a31077e759bd6104d5bca8868a554fde..afefc740d707cd6fd01420e950eafd37abe80119 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" @@ -73,6 +74,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -160,8 +163,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true, - /*use_fusion=*/false); + /*rewrite_grad_op=*/true); // Rewrite gather ops into smaller ones. pass.AddPass(); @@ -174,6 +176,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -200,18 +203,28 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker(); pipeline.AddPass(); pipeline.AddPass(); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + + { + HloPassPipeline pipeline("layout_assignment"); + pipeline.AddPass( + hlo_module->mutable_device_entry_computation_layout(), stream_exec); + + // The LayoutAssignment pass may leave behind kCopy instructions which are + // duplicate or NOPs, so remove them with algebraic simplification and CSE. + pipeline.AddPass>( + /*is_layout_sensitive=*/true, + /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { + return true; + }); // Choose the fastest algorithm for each conv. // - // In theory doing this here is way too early: It needs to happen after - // layout assignment, because the layout of the inputs/outputs affects the - // speed of the conv. But currently we only allow only one input/output - // layout when calling cudnn, so there's no ambiguity. - // - // We pick the algorithm at this early stage so we can generate better HLO. - // After CudnnConvolutionRewriter, our convolutions are CustomCalls which - // return a tuple (conv_result, scratch_memory), and the each conv uses 0 - // bytes of scratch: + // We pick the algorithm before fusion so we can generate better HLO. After + // CudnnConvolutionRewriter, our convolutions are CustomCalls which return a + // tuple (conv_result, scratch_memory), and the each conv uses 0 bytes of + // scratch: // // customcall = (f32[...], f32[0]) // return gte(customcall, 0) @@ -227,35 +240,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, // The new tuple and gte instructions then be simplified away, because // nobody is expected to use the scratch value. // - // However, if we were to run CudnnConvolutionAlgorithmPicker after layout - // assignment, fusion would already have run, and the gte(customcall, 0) - // would probably already be into a fusion node. We can't simplify across - // HloComputation boundaries, so in this case we wouldn't be able to - // simplify away the new_tuple bits. - // - // We'll need to revisit this if we ever allow multiple layouts for the - // inputs/outputs of a cudnn convolution. + // However, if we were to run CudnnConvolutionAlgorithmPicker after fusion + // the gte(customcall, 0) would probably already be into a fusion node. We + // can't simplify across HloComputation boundaries, so in this case we + // wouldn't be able to simplify away the new_tuple bits. pipeline.AddPass(stream_exec, device_allocator); // Clean up new_tuple described above. pipeline.AddPass(); - pipeline.AddPass(); - TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); - } - - { - HloPassPipeline pipeline("layout_assignment"); - pipeline.AddPass( - hlo_module->mutable_device_entry_computation_layout()); - - // The LayoutAssignment pass may leave behind kCopy instructions which are - // duplicate or NOPs, so remove them with algebraic simplification and CSE. - pipeline.AddPass>( - /*is_layout_sensitive=*/true, - /*valid_bitcast_callback=*/[](const Shape&, const Shape&) { - return true; - }); pipeline.AddPass(/*is_layout_sensitive=*/true); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -266,6 +259,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); + fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); @@ -282,6 +276,15 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); } } + + { + // Do an aggressive LICM pass over while loops. In particular, this hoists + // constants that were sunk by WhileLoopConstantSinking. Leaving them in + // the while loop may result in unnecessary copies. + HloPassPipeline pipeline("while-loop-licm"); + pipeline.AddPass(true); + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index d9560779f313d5a559c3eb0f5b28ff5dd210d9d5..c5ccdd4a7dcec02ddab8a1f748659de41f6202d2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -78,12 +78,6 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } - } else if (IsCustomCallToDnnConvolution(*hlo)) { - // The last two arguments to a CUDNN convolution are two HLO constants for - // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. - for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { - TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); - } } else if (ImplementedAsLibraryCall(*hlo) || hlo->opcode() == HloOpcode::kCrossReplicaSum) { // For all other library calls and cross-replica-sum, materialize all the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 89f1e625884568bf7370b3801d851ef4846c2a98..8bf62dde8b9948375fc493fd1a524cfa7b062502 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -18,31 +18,72 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace gpu { -// cuDNN convolutions are called with specific layouts on the input, output, -// and filter: -// -// input: DataLayout::kBatchDepthYX -// output: DataLayout::kBatchDepthYX -// filter: FilterLayout::kOutputInputYX -// -// The order dimensions in the constant name is major-to-minor (eg, the -// most-major dimension of the input is batch, most-minor is X). The -// specific dimension numbers these named dimensions correspond to is -// determined by the ConvolutionDimensionNumbers argument. Y is spatial -// dimension 0, and X is spatial dimension 1. -// -// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. -static Status AddBackendConstraintsToDnnConvCustomCall( +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::FilterLayout; + +static bool IsVoltaOrLater(const se::StreamExecutor& stream_executor) { + int major, minor; + CHECK(stream_executor.GetDeviceDescription().cuda_compute_capability(&major, + &minor)); + return major >= 7; +} + +// Returns (input, filter, output) layouts. +static std::tuple +HeuristicLayoutAssignment(const HloInstruction* instr, + stream_executor::StreamExecutor* stream_executor) { + // DataLayout and FilterLayout uses weird enum names. Translations: + // N <=> Batch or Output + // C <=> Depth or Input + // H <=> Y + // W <=> X + // + // Therefore kOutputInputYX means NHWC; kBatchDepthYX means NCHW. + + // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x + // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version + // changes, as well as the hardware updates. + if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 && + IsVoltaOrLater(*stream_executor))) { + return std::make_tuple(DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString(); + // For BackwardInput that has stride, full NHWC layouts run significantly + // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC). + // + // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW, + // NHWC). + if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget && + window_util::HasStride(instr->window())) { + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX); + } + return std::make_tuple(DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth); +} + +// Adds layout constraints on the cudnn custom-call instruction. The layout +// constraints are represented in terms of minor_to_major fields of both +// operands and the output shape. Depending on the underlying algorithm, one of +// { NCHW, NHWC } ^ 3 = 8 different layout combinations may be chosen. +Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall( HloInstruction* instr, LayoutConstraints* constraints) { CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); Shape input_shape; @@ -66,39 +107,25 @@ static Status AddBackendConstraintsToDnnConvCustomCall( << instr->custom_call_target(); } - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. - const ConvolutionDimensionNumbers& dimension_numbers = - instr->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; - --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + { + DataLayout input; + FilterLayout filter; + DataLayout output; + if (ConvUseLayoutHeuristic(instr->GetModule()->config())) { + std::tie(input, filter, output) = + HeuristicLayoutAssignment(instr, stream_executor_); + } else { + input = DataLayout::kBatchDepthYX; + filter = FilterLayout::kOutputInputYX; + output = DataLayout::kBatchDepthYX; + } + + TF_ASSIGN_OR_RETURN( + std::tie(*input_shape.mutable_layout(), *filter_shape.mutable_layout(), + *output_shape.mutable_layout()), + StreamExecutorConvLayoutsToXlaLayouts( + instr->convolution_dimension_numbers(), input, filter, output)); } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); // The custom call returns a tuple of (actual_result, scratch_buffer); // call_result_buf is the logical buffer for actual_result, the thing that @@ -132,7 +159,13 @@ static Status AddBackendConstraintsToDnnConvCustomCall( Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto* instruction : constraints->computation()->instructions()) { + // Add convolution constraints in reverse postorder that the earliest + // convolution layout propagates first. This reduces the likelihood of fusion + // nodes with copies. + auto post_order = constraints->computation()->MakeInstructionPostOrder(); + for (auto iterator = post_order.rbegin(); iterator != post_order.rend(); + ++iterator) { + HloInstruction* instruction = *iterator; if (IsCustomCallToDnnConvolution(*instruction)) { TF_RETURN_IF_ERROR( AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 86a3a7111fd79494e469beecf3234f6cec9adb9c..ce24af1cf8856920ccf438b5bbd2ef28cfa8ba6f 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { @@ -27,8 +28,10 @@ namespace gpu { // layout constraints for operands and results of library calls. class GpuLayoutAssignment : public LayoutAssignment { public: - explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout) - : LayoutAssignment(entry_computation_layout) {} + explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout, + se::StreamExecutor* stream_executor) + : LayoutAssignment(entry_computation_layout), + stream_executor_(stream_executor) {} ~GpuLayoutAssignment() override {} protected: @@ -41,6 +44,12 @@ class GpuLayoutAssignment : public LayoutAssignment { LayoutConstraints* constraints) override; bool CustomCallRequiresMajorFirstLayout( const HloInstruction* instruction) override; + + private: + Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints); + + se::StreamExecutor* stream_executor_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index 4c45d2e94aebce5496da94841f6a1ae9015615c1..e48165c1426ea04839c245bc20b851a0f1710246 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -69,7 +69,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape_with_layout); - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); for (const HloInstruction* operand : add->operands()) { @@ -156,7 +157,8 @@ TEST_F(LayoutAssignmentTest, BatchNormInference) { *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -225,7 +227,8 @@ TEST_F(LayoutAssignmentTest, BatchNormTraining) { {result_shape, offset_scale_shape, offset_scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first operand to batchnorm should have the same layout as the @@ -305,7 +308,8 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) { {result_shape, scale_shape, scale_shape})); } - GpuLayoutAssignment layout_assignment(&computation_layout); + GpuLayoutAssignment layout_assignment( + &computation_layout, backend().default_stream_executor()); EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); // The first and fourth operands to the batchnorm call should have the diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..35b4b4e20b633792de4251a4b0e89f4b579053ce --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.cc @@ -0,0 +1,28 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_options.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace xla { +namespace gpu { + +bool ConvUseLayoutHeuristic(const HloModuleConfig& config) { + return !config.debug_options().xla_backend_extra_options().count( + "xla_gpu_experimental_conv_disable_layout_heuristic"); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/compiler/xla/service/gpu/gpu_options.h new file mode 100644 index 0000000000000000000000000000000000000000..498d4a94955cb2c50e0b165f28ded44ac1c0bfff --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_options.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_module_config.h" + +// Helper functions for querying options that are specific to the GPU backend. + +namespace xla { +namespace gpu { + +// Returns true if we should use heuristics to assign convolution layouts, as +// opposed to always assigning NCHW. +bool ConvUseLayoutHeuristic(const HloModuleConfig& config); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index f766f968826d960a8e86308f2395301aaa09f1ae..375709150e08996ea6a40f5e9e66a8f8d9287008 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -199,7 +199,7 @@ StatusOr> HloSchedule::Build( // concurrency by optimizing for minimal memory usage. TF_ASSIGN_OR_RETURN( schedule->thunk_launch_order_, - CreateMemoryMinimizingSequence( + ScheduleOneComputation( *entry_computation, [pointer_size](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); })); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc index e230d538cc2df826778e8d13eaaaf31ec81c57f0..45f0a1c645b2875cf90d2c11cfb66c3dd855d097 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc @@ -47,8 +47,7 @@ class HloScheduleTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } HloVec RemoveHlo(const HloVec& input, diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index 3ddc1c0789d746bf021256638342364aac63e0e3..ae310beefad0c81c17fd4140b441b3a19a002e2c 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -49,13 +49,25 @@ void InfeedManager::EnqueueBuffers(const std::vector& buffers) { } InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { - tensorflow::mutex_lock l(mu_); - while (enqueued_buffer_.empty()) { - cv_.wait(l); + bool became_empty = false; + InfeedBuffer* current_buffer; + { + tensorflow::mutex_lock l(mu_); + while (enqueued_buffer_.empty()) { + cv_.wait(l); + } + current_buffer = enqueued_buffer_.front(); + enqueued_buffer_.pop_front(); + dequeued_buffer_.insert(current_buffer); + if (enqueued_buffer_.empty()) { + became_empty = true; + } + } + if (became_empty) { + for (const auto& callback : on_empty_callbacks_) { + callback(); + } } - InfeedBuffer* current_buffer = enqueued_buffer_.front(); - enqueued_buffer_.pop_front(); - dequeued_buffer_.insert(current_buffer); return current_buffer; } @@ -88,6 +100,10 @@ se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { return host_to_device_stream_.get(); } +void InfeedManager::RegisterOnEmptyCallback(std::function callback) { + on_empty_callbacks_.push_back(std::move(callback)); +} + InfeedManager* GetOrCreateInfeedManager() { static InfeedManager* manager = new InfeedManager; return manager; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index d5f2216d460a45085536b15f9bf6e3bd3579f9c8..a3fc15cfe36a490f38daabca9ff36fbb1012aead 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -21,6 +21,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_ #include +#include #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/flatset.h" @@ -100,6 +101,10 @@ class InfeedManager { // returns null. se::Stream* GetStream(se::StreamExecutor* executor); + // Registers a callback that will be called when 'enqueued_buffer_' becomes + // empty. + void RegisterOnEmptyCallback(std::function callback); + private: // TODO(b/30467474): Revisit if this mutex becomes a point of // contention. @@ -122,6 +127,10 @@ class InfeedManager { // Executor that the host_to_device_stream belongs to. Not owned. se::StreamExecutor* host_to_device_executor_; + + // List of callbacks which will be called when 'enqueued_buffer_' becomes + // empty. + std::vector> on_empty_callbacks_; }; // Singleton creator-or-accessor: Returns the GPU infeed manager. diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 5d5bef6b57b57fce4255a145634745b38dccacc7..6c4519185b34989eb53c884ba214d69b824b113c 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -77,15 +77,14 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction* producer = consumer->mutable_operand(operand_index); // Check if we can use output fusion for (A @ B) * alpha - if (consumer->operand_count() == 2 && - (producer->opcode() == HloOpcode::kDot || - (producer->opcode() == HloOpcode::kFusion && - producer->fused_expression_root()->opcode() == HloOpcode::kDot))) { + if (producer->opcode() == HloOpcode::kDot || + (producer->opcode() == HloOpcode::kFusion && + producer->fused_expression_root()->opcode() == HloOpcode::kDot)) { int64 other_operand_index = 1 - operand_index; - const HloInstruction* alpha = consumer->operand(other_operand_index); HloInstruction* op1 = nullptr; HloInstruction* op2 = nullptr; - if (consumer->opcode() == HloOpcode::kFusion && + if (consumer->operand_count() == 1 && + consumer->opcode() == HloOpcode::kFusion && consumer->fusion_kind() == HloInstruction::FusionKind::kLoop && Match(consumer->fused_expression_root(), match::Op() @@ -103,10 +102,12 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, op2->opcode() != HloOpcode::kBroadcast) { return false; } - if (IsIEEEFloatingPointScalarConstant(alpha)) { + if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) { return true; } - } else if (consumer->opcode() == HloOpcode::kMultiply) { + } else if (consumer->operand_count() == 2 && + consumer->opcode() == HloOpcode::kMultiply) { + const HloInstruction* alpha = consumer->operand(other_operand_index); // Fuse if 'alpha' is a broadcast of a scalar constant. if (alpha->opcode() == HloOpcode::kBroadcast && alpha->dimensions().empty() && @@ -173,10 +174,38 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } + // Fuse scalar constants into loop fusion nodes, this reduces the number of + // parameters and makes matching scalar broadcasts easier. + if (ShapeUtil::IsEffectiveScalar(producer->shape()) && + consumer->opcode() == HloOpcode::kFusion && + producer->opcode() == HloOpcode::kConstant) { + return true; + } + return IsFusile(*producer) && IsFusile(*consumer) && InstructionFusion::ShouldFuse(consumer, operand_index); } +bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + const HloInstruction* producer = consumer->operand(operand_index); + // The IR emitter has limited support for non-loop fusions with multi output + // at present. + // TODO(tjoerg): Relax this constraint to allow for arbitraty kinds of fusion. + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Multi-output fusion requires instructions with compatible shapes. + if (!ShapeUtil::Compatible(producer->shape(), consumer->shape())) { + return false; + } + // TODO(tjoerg): Stop calling `ShouldFuse` to relax the criteria for + // multi-output fusion. In particular, do not check whether an instruction is + // expensive to duplicate, since this doesn't matter here. + return GpuInstructionFusion::ShouldFuse(consumer, operand_index); +} + HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { if (IsReductionToVector(*consumer)) { diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h index 9fb06b0a244186484b1c17edf13bd28a4305a1a6..f629d9ff2c7165b652369612c30979150f93bd24 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h @@ -31,6 +31,9 @@ class GpuInstructionFusion : public InstructionFusion { bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) override; }; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 760e0e90f583d0e43975e23b731a40af75c7dc17..1963d9eef72d41fa0a275bea98f959671fa7e737 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace op = xla::testing::opcode_matchers; @@ -140,7 +143,7 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { // Tests that broadcasts fused into a fusion with a reduce root. TEST_F(InstructionFusionTest, BroadcastIntoReduce) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module add { @@ -165,11 +168,11 @@ TEST_F(InstructionFusionTest, BroadcastIntoReduce) { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Fusion()); EXPECT_THAT(root->fused_expression_root(), - op::Reduce(op::Broadcast(op::Parameter()), op::Parameter())); + op::Reduce(op::Broadcast(op::Constant()), op::Constant())); } TEST_F(InstructionFusionTest, BitcastIntoAdd) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -191,7 +194,7 @@ TEST_F(InstructionFusionTest, BitcastIntoAdd) { } TEST_F(InstructionFusionTest, AddIntoBitcast) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY BroadcastIntoAdd { @@ -213,7 +216,7 @@ TEST_F(InstructionFusionTest, AddIntoBitcast) { } TEST_F(InstructionFusionTest, DontFuseGTE) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY DontFuseGTE { p0 = (f32[10], f32[10]) parameter(0) @@ -229,7 +232,7 @@ TEST_F(InstructionFusionTest, DontFuseGTE) { } TEST_F(InstructionFusionTest, DotOutputFusion) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { alpha = f32[] constant(3) @@ -252,13 +255,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) { EXPECT_THAT( root->fused_expression_root(), op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); } // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is // duplicated and fused into both reduces. TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = f32[] parameter(0) @@ -281,14 +284,15 @@ TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) { .ValueOrDie()); HloInstruction* root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())); + EXPECT_THAT(root, op::Tuple(op::Fusion(), op::Fusion())) + << module->ToString(); } // Compute sum(100/p0), where p0 has type s32, twice. Check that the division // is *not* duplicated and fused into both reduces, because we say that integer // division is not cheap. TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module Add { lhs = s32[] parameter(0) @@ -308,11 +312,12 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) { EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) .Run(module.get()) - .ValueOrDie()); + .ValueOrDie()) + << module->ToString(); } TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY NoOutputFusion { alpha = f32[] constant(3) @@ -334,7 +339,271 @@ TEST_F(InstructionFusionTest, DotOutputFusionImpossible) { EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); EXPECT_THAT(root->fused_expression_root(), op::Multiply(op::Multiply(op::Parameter(), op::Parameter()), - op::Broadcast(op::Parameter()))); + op::Broadcast(op::Constant()))); +} + +// Counts the HLO ops with a given op code in the specified module. +static int Count(const HloModule& module, HloOpcode op) { + int count = 0; + for (const auto* computation : module.computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == op) { + ++count; + } + } + } + return count; +} + +// Returns an HLO instruction from the given computation with the op code. +static StatusOr FindHloInstruction( + const HloComputation& computation, HloOpcode op) { + for (const auto* instruction : computation.instructions()) { + if (instruction->opcode() == op) { + return instruction; + } + } + return NotFound( + "Computation '%s' does not contain an instruction with op code '%s'.", + computation.name().c_str(), HloOpcodeString(op).c_str()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion) { + // sub --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT( + fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Parameter()), op::Subtract())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionExpensiveOp) { + // tanh --> add --> tuple + // \---------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + tanh = f32[4,3]{1,0} tanh(p0) + add = f32[4,3]{1,0} add(tanh, p1) + ROOT tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}) tuple(tanh, add) + })") + .ValueOrDie(); + + // TODO(tjoerg): Allow multi-output fusion for expensive operations like tanh. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion2) { + // sub --> add1 --\--------\ + // \----------> add2 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(sub, add1) + ROOT tuple = (f32[4,3]{1,0}) tuple(add1, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Subtract(), op::Add()), + op::Add(op::Subtract(), op::Parameter()))); +} + +TEST_F(InstructionFusionTest, MultiOutputFusion3) { + // sub --> add1 ----\--------\ + // \ --> add2 --> add3 --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[4,3]{1,0} parameter(2) + p3 = f32[4,3]{1,0} parameter(3) + sub = f32[4,3]{1,0} subtract(p0, p2) + add1 = f32[4,3]{1,0} add(sub, p1) + add2 = f32[4,3]{1,0} add(p2, sub) + add3 = f32[4,3]{1,0} add(add1, add2) + ROOT tuple = (f32[4,3]{1,0}) tuple(add3, add2) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + SCOPED_TRACE(module->ToString()); + + // Expect that there is one multi-output fusion and subtract has not been + // duplicated. + EXPECT_EQ(Count(*module, HloOpcode::kFusion), 1); + EXPECT_EQ(Count(*module, HloOpcode::kSubtract), 1); + TF_ASSERT_OK_AND_ASSIGN( + const HloInstruction* fusion, + FindHloInstruction(*module->entry_computation(), HloOpcode::kFusion)); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Add(op::Add(), op::Add()), + op::Add(op::Parameter(), op::Subtract()))); +} + +TEST_F(InstructionFusionTest, NoCyclesDueToMultiOutputFusion) { + // sub --> mul ---\ + // \--> call --> add --> tuple + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + c = f32[] constant(42) + p0 = f32[4,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + sub = f32[4,3]{1,0} subtract(p0, p1) + mul = f32[4,3]{1,0} multiply(sub, c) + call = f32[4,3]{1,0} custom-call(sub), custom_call_target="foo" + add = f32[4,3]{1,0} add(mul, call) + ROOT tuple = (f32[4,3]{1,0}) tuple(add) + })") + .ValueOrDie(); + + ASSERT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + // Visit instructions in post order to detect cycles. + // TODO(tjoerg): Add cycle detection to the HloVerifier. + class DummyVisitor : public DfsHloVisitorWithDefault { + public: + DummyVisitor() {} + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + } visitor; + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + // Accept will return a FailedPrecondition when a cycle is detected. + EXPECT_TRUE(computation->root_instruction()->Accept(&visitor).ok()); + } +} + +TEST_F(InstructionFusionTest, NoMultiOutputFusionWithIncompatibleShapes) { + // sub[2,3] --> add[4,3] --> tuple([2,3], [4,3]) + // \-------------------------/ + auto module = ParseHloString(R"( + HloModule test_module + ENTRY OutputFusion { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[4,3]{1,0} parameter(1) + p2 = f32[2,3]{1,0} parameter(2) + sub = f32[2,3]{1,0} subtract(p0, p2) + add = f32[4,3]{1,0} add(sub, p1) + ROOT tuple = (f32[2,3]{1,0}, f32[4,3]{1,0}) tuple(sub, add) + })") + .ValueOrDie(); + + // Multi-output fusion requires shapes to be compatible. Since `sub` and `add` + // have incompatible shapes, expect that no multi-output fusion happens. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseIntoInputFusionInstruction) { + auto module = ParseHloString(R"( + HloModule test_module + + add_computation { + add_lhs = f32[] parameter(0) + add_rhs = f32[] parameter(1) + ROOT add_root = f32[] add(add_lhs, add_rhs) + } + + fused_computation { + p1 = f32[10] parameter(0) + zero = f32[] constant(0) + ROOT f2_root = f32[] reduce(p1, zero), dimensions={0}, + to_apply=add_computation + } + + ENTRY entry { + p0 = f32[10] parameter(0) + mul = f32[10] multiply(p0, p0) + fusion = f32[] fusion(mul), kind=kInput, calls=fused_computation + ROOT tuple = (f32[10], f32[]) tuple(fusion, mul) + })") + .ValueOrDie(); + + // Multi-output fusion is not supported for non-loop fusions at present. Since + // `fused_computation` is a input fusion, expect no multi-output fusion to + // happen. + ASSERT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()) + << module->ToString(); +} + +TEST_F(InstructionFusionTest, FuseScalarConstant) { + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY FuseScalarConstant { + p0 = f32[] parameter(0) + c0 = f32[] constant(1) + add1 = f32[] add(p0, c0) + b0 = f32[2]{0} broadcast(add1), dimensions={} + c1 = f32[2]{0} constant({1, 2}) + ROOT add2 = f32[2]{0} add(b0, c1) + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::Fusion()); + EXPECT_THAT(root->fused_expression_root(), + op::Add(op::Broadcast(op::Add(op::Parameter(), op::Constant())), + op::Parameter())); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 22e715099526c20532bb298e84e50457d89f615e..67890bfed1136796c83c7ef6912ffc1ab1b7e332 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -162,19 +162,8 @@ static HloInstruction* CreateCudnnConv( Shape call_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); - // Our CustomCall takes four arguments: The conv lhs and rhs, the cudnn - // algorithm to use, and a boolean indicating whether to use tensor cores. - // - // It's up to a later pass to choose the algorithm and decide whether to use - // tensor cores, so to indicate that we haven't yet made a choice, we speicfy - // -1 and false for those args. - HloInstruction* negative_one = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-1))); - HloInstruction* false_constant = computation->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(false))); - HloInstruction* custom_call = - computation->AddInstruction(HloInstruction::CreateCustomCall( - call_shape, {lhs, rhs, negative_one, false_constant}, call_target)); + HloInstruction* custom_call = computation->AddInstruction( + HloInstruction::CreateCustomCall(call_shape, {lhs, rhs}, call_target)); custom_call->set_window(window); custom_call->set_convolution_dimension_numbers(dnums); return custom_call; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 1e0db2821a2c212d0f212ae94ab69231bc6053ea..547af33e9a98c03e1429366172f9a401e385a9d1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -94,7 +94,10 @@ Status IrEmitter::HandleConstant(HloInstruction* constant) { << std::endl << " its type: " << llvm_ir::DumpToString(*global_for_const->getType()); - bindings_.BindHloToIrValue(*constant, global_for_const); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global_for_const, + llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); + bindings_.BindHloToIrValue(*constant, shape_constant); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index b0accc08d479258d65a18202122e4c9e90ff78d0..e55dfc6dae844ceb1d28ad389d133c80823bad9a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -120,10 +120,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetBasePointer(const HloInstruction& inst) const { return bindings_.GetBasePointer(inst); } - // A convenient helper for calling BufferAssignment::GetUniqueTopLevelSlice. - BufferAllocation::Slice GetAllocationSlice(const HloInstruction& hlo) const { + // A convenient helper for calling BufferAssignment::GetUniqueSlice. + BufferAllocation::Slice GetAllocationSlice( + const HloInstruction& hlo, const ShapeIndex& index = {}) const { return ir_emitter_context_->buffer_assignment() - .GetUniqueTopLevelSlice(&hlo) + .GetUniqueSlice(&hlo, index) .ConsumeValueOrDie(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 55d4c1d13d3ad41e09d48db70478cf5e6af59808..726434c3dfd4f1ef866d2eb9b6d7eb8b659e0984 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" @@ -58,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" @@ -79,6 +81,7 @@ namespace { using llvm_ir::IrName; using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::InlinedVector; using tensorflow::gtl::nullopt; using tensorflow::gtl::optional; using tensorflow::strings::StrCat; @@ -422,15 +425,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); - const HloInstruction* algorithm_inst = custom_call->operand(2); - CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); - int64 algorithm = algorithm_inst->literal().Get({}); - - const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); - CHECK(tensor_ops_enabled_inst->IsConstant()) - << tensor_ops_enabled_inst->ToString(); - bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); - + TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config, + custom_call->backend_config()); const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { @@ -445,7 +441,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardInput, @@ -458,7 +455,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardFilter, @@ -471,7 +469,8 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, tensor_ops_enabled, custom_call); + backend_config.algorithm(), backend_config.tensor_ops_enabled(), + custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); @@ -499,12 +498,31 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // initializes the output array to the initial value of the reduce. if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) { switch (root->opcode()) { + case HloOpcode::kTuple: case HloOpcode::kReduce: { VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString(); - TF_ASSIGN_OR_RETURN(std::unique_ptr initializer_thunk, - BuildInitializerThunk(fusion)); std::vector> thunks; - thunks.push_back(std::move(initializer_thunk)); + ArraySlice output_instructions = + root->opcode() == HloOpcode::kTuple + ? root->operands() + : ArraySlice(&root, 1); + + // For multi-output fusion emit an initializer for each tuple element. + // Otherwise it's sufficient to just initialize the single output. + HloInstruction* first_reduce = nullptr; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + if (output_instructions[i]->opcode() == HloOpcode::kReduce) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr initializer_thunk, + BuildInitializerThunk(fusion, output_instructions[i] == root + ? ShapeIndex() + : ShapeIndex({i}))); + thunks.push_back(std::move(initializer_thunk)); + first_reduce = + first_reduce == nullptr ? output_instructions[i] : first_reduce; + } + } + CHECK(first_reduce != nullptr); thunks.push_back(BuildKernelThunk(fusion)); thunk_sequence_->emplace_back( MakeUnique(std::move(thunks), fusion)); @@ -518,11 +536,50 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - Shape input_shape = root->operand(0)->shape(); - return EmitReductionToVector( - root, input_shape, fused_emitter.GetGenerator(root->operand(0)), - fused_emitter.GetGenerator(root->operand(1)), root->dimensions(), - root->to_apply()); + // For multi-output fusion CHECK the constraints and feed all the + // reduces into a single loop code generator. Single-output reduce + // fusion is a special case of that. + InlinedVector input_gens; + InlinedVector init_value_gens; + std::vector> + extra_output_gens; + InlinedVector reducers; + InlinedVector reduce_output_shapes; + for (int i = 0, e = output_instructions.size(); i != e; ++i) { + const HloInstruction* inst = output_instructions[i]; + ShapeIndex output_shape_index; + if (root->opcode() == HloOpcode::kTuple) { + output_shape_index = {i}; + } + // TODO(kramerb): CHECK that layouts are equal. Currently this + // breaks multioutputfusion_test. The test has pre-fused + // instructions, but layout_assignment will not assign any layouts + // for instructions inside of a fused computation. It just removes + // the layouts instead. + if (inst->opcode() == HloOpcode::kReduce) { + CHECK(ShapeUtil::Compatible(first_reduce->shape(), inst->shape())); + CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), + inst->operand(0)->shape())); + CHECK(ShapeUtil::Compatible(first_reduce->operand(1)->shape(), + inst->operand(1)->shape())); + CHECK(first_reduce->dimensions() == inst->dimensions()); + input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0))); + init_value_gens.push_back( + fused_emitter.GetGenerator(inst->operand(1))); + reducers.push_back(inst->to_apply()); + reduce_output_shapes.push_back(std::move(output_shape_index)); + } else { + CHECK(ShapeUtil::Compatible(first_reduce->operand(0)->shape(), + inst->shape())); + extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst), + std::move(output_shape_index)); + } + } + const Shape& input_shape = first_reduce->operand(0)->shape(); + return EmitReductionToVector(first_reduce, input_shape, input_gens, + init_value_gens, + first_reduce->dimensions(), reducers, + reduce_output_shapes, extra_output_gens); } default: LOG(FATAL) << "Bad opcode for input fusion: " @@ -907,10 +964,33 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { return IrEmitter::HandleCopy(copy); } +Status IrEmitterUnnested::EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { + for (int i = 0; i != extra_output_gens.size(); ++i) { + const HloInstruction* output = reduce->parent()->FusionInstruction(); + llvm::Value* extra_output_address = + GetIrArray(*output, *output, extra_output_gens[i].second) + .EmitArrayElementAddress(index, &ir_builder_, + "extra_output_element_address"); + TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, + extra_output_gens[i].first(index)); + ir_builder_.CreateStore(extra_output_ir_value, extra_output_address); + } + return Status::OK(); +} + Status IrEmitterUnnested::EmitReductionToScalar( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Number of elements processed by a single thread. constexpr int64 kTileSize = 16; int64 num_elems = ShapeUtil::ElementsIn(input_shape); @@ -962,14 +1042,19 @@ Status IrEmitterUnnested::EmitReductionToScalar( // auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } llvm::Value* x_in_tiles = tile_index[0]; @@ -1002,11 +1087,16 @@ Status IrEmitterUnnested::EmitReductionToScalar( llvm_ir::IrArray::Index input_index( /*linear=*/x, input_shape, &ir_builder_); llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, extra_output_gens); }; // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's @@ -1041,20 +1131,24 @@ Status IrEmitterUnnested::EmitReductionToScalar( : element_ir_type; for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1070,14 +1164,21 @@ Status IrEmitterUnnested::EmitReductionToScalar( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0), - output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + /*linear=*/ir_builder_.getInt64(0), + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterates through all input tiles, one per thread. @@ -1097,8 +1198,13 @@ Status IrEmitterUnnested::EmitReductionToScalar( Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // Divide the input matrix into tiles of size Kx1. For example, when the // input matrix is 4x4 and K=2, the tiled matrix looks like // @@ -1108,9 +1214,13 @@ Status IrEmitterUnnested::EmitColumnReduction( // 4567 // Numbers indicate tile IDs. // // Each tile is first partially reduced to a scalar by a thread, and then the - // scalar is accumulated to the output vector using atomic operations. We - // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel. - constexpr int64 kTileSize = 16; + // scalar is accumulated to the output vector using atomic operations. + // + // We choose 128 as the tile size based on empirical evidence. It's big enough + // to reduce the amount of atomic adds in the end, maximizing the memory + // bandwidth. + constexpr int64 kTileSize = 128; + // If the height is not a multiple of the tile size, we pad the bottom of the // input matrix. const int64 height_in_tiles = CeilOfRatio(height, kTileSize); @@ -1140,15 +1250,20 @@ Status IrEmitterUnnested::EmitColumnReduction( // } auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + const int num_reduces = reducers.size(); // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } // Emit an inner for-loop that partially reduces the elements in the given @@ -1206,13 +1321,18 @@ Status IrEmitterUnnested::EmitColumnReduction( .SourceIndexOfTranspose(normalized_input_shape, input_shape, transpose_dimension_mapping, &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); } - return (EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address)); }; // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's @@ -1241,13 +1361,20 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + x, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, partial_reduction_result_addresses[i])); + } + return Status::OK(); }; // Emit a parallel loop that iterate through all input tiles. @@ -1265,12 +1392,42 @@ Status IrEmitterUnnested::EmitColumnReduction( .EmitLoop(IrName(reduce)); } +static std::pair ComputeTilingSchemeForReduction( + int64 depth, int64 width, int64 kWarpSize) { + constexpr int64 kTargetNumElementsPerThread = 64; + int64 x_tile_size = kTargetNumElementsPerThread; + int64 z_tile_size = 1; + + // Only tile along the x dimension with tile size kTargetNumElementsPerThread + // if doing so doesn't require a slow version of loop with bound check on each + // dimension. A more sophisticated heuristics is to enable tile along the + // x dimension with tile size kTargetNumElementsPerThread when either width is + // a factor of (kWarpSize * kTargetNumElementsPerThread) or width is big + // enough so that only a small fraction of the threads execute the slow + // version of loop with bound check. + if (width % (kWarpSize * kTargetNumElementsPerThread) != 0) { + x_tile_size = 8; + z_tile_size = 8; + while (depth % z_tile_size != 0) { + z_tile_size -= 1; + } + } + + return std::pair(x_tile_size, z_tile_size); +} + Status IrEmitterUnnested::EmitRowReduction( int64 depth, int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // A naive algorithm is: - // 1. Divide the input tensor into tiles of size 1x1xK. + // 1. Divide the x dimension of the input tensor into tiles of size 1x1xX. // 2. Partially reduces each tile to a scalar using one thread. // 3. Accumulates that scalar to the output vector using atomic operations. // @@ -1281,15 +1438,15 @@ Status IrEmitterUnnested::EmitRowReduction( // int y = linear_index / width_in_tiles % height; // int z = linear_index / (height * width_in_tiles); // float partial_result = 0; - // for (element_id_in_tile : range(kTileSize)) { - // int x = x_in_tiles * kTileSize + element_id_in_tile; + // for (element_id_in_tile : range(x_tile_size)) { + // int x = x_in_tiles * x_tile_size + element_id_in_tile; // if (x < width) // partial_result = reducer(partial_result, input[z][y][z]); // } // AtomicReducer(&output[y], partial_result); // } // - // Three optimizations are performed. + // Four optimizations are performed. // // 1. To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead @@ -1316,29 +1473,44 @@ Status IrEmitterUnnested::EmitRowReduction( // element_id_in_tile, which makes the code more friendly to optimizations // such as LICM. // + // 4. When the width is too small and x_tile_size is less than the target + // number of elements per thread and use a small factor of depth as + // z_tile_size to increase the number of elements calculated by each + // partial sum. This can reduce the needed number of dynamic shfl_down and + // atomic operations. + // // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x; // linear_index < depth * height * width_in_tiles; // linear_index += blockDim.x * gridDim.x) { // int x_in_tiles = linear_index % width_in_tiles; // int y = linear_index / width_in_tiles % height; - // int z = linear_index / (height * width_in_tiles); + // int z_in_tiles = linear_index / (height * width_in_tiles); // int warp_id = x_in_tiles / warpSize; // int lane_id = x_in_tiles % warpSize; // float partial_result = 0; // int x = warp_id * kTileSize * warpSize + lane_id; - // if (width % (kTileSize * warpSize) == 0 || - // x + (kTileSize - 1) * warpSize < width) { - // // The entire tile is in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; - // ++element_id_in_tile, x += warpSize) { - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (width % (x_tile_size * warpSize) == 0 || + // x + (x_tile_size - 1) * warpSize < width) { + // // The entire x_tile is in bounds. + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0;element_id_in_x_tile < x_tile_size; + // ++element_id_in_x_tile, x += warpSize) { + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } else { // // The tile is partially in bounds. - // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize; + // for (int element_id_in_z_tile = 0; element_id_in_z_tile < z_tile_size; + // ++element_id_in_z_tile) { + // z = z_in_tiles * z_tile_size + element_id_in_z_tile; + // for (int element_id_in_x_tile = 0; element_id_in_x_tile < + // x_tile_size; // ++element_id_in_tile, x += warpSize) { - // if (x < width) - // partial_result = Reducer(partial_result, input[z][y][x]); + // if (x < width) + // partial_result = Reducer(partial_result, input[z][y][x]); + // } // } // } // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2) @@ -1349,29 +1521,35 @@ Status IrEmitterUnnested::EmitRowReduction( // AtomicReducer(&output[y], partial_result); // } // - // Choose 8 as the tile size, which matches Eigen's RowReduceKernel. - constexpr int64 kTileSize = 8; + + int64 x_tile_size; + int64 z_tile_size; + std::tie(x_tile_size, z_tile_size) = + ComputeTilingSchemeForReduction(depth, width, kWarpSize); + // Round the width in tiles up to the nearest multiple of kWarpSize, so that // the use of shfl_down is valid. const int64 width_in_tiles = - RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize); + RoundUpToNearest(CeilOfRatio(width, x_tile_size), kWarpSize); - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& tile_index) -> Status { - // Emit the loop body that reduces one tile. + auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) { + // Emit the loop body that reduces one z-x-tile. + const int num_reduces = reducers.size(); llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( input_shape.element_type(), ir_emitter_context_->llvm_module()); - llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); - { - TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, - init_value_gen(llvm_ir::IrArray::Index({}))); + std::vector partial_reduction_result_addresses; + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, + "partial_reduction_result." + llvm::Twine(i)); + TF_ASSIGN_OR_RETURN(llvm::Value* const init_ir_value, + init_value_gens[i](llvm_ir::IrArray::Index({}))); ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + partial_reduction_result_addresses.push_back( + partial_reduction_result_address); } - // Emit an inner for-loop that partially reduces the elements in the given - // tile. - llvm::Value* z = tile_index[0]; + llvm::Value* z_tile = tile_index[0]; llvm::Value* y = tile_index[1]; llvm::Value* x_tile = tile_index[2]; llvm::Value* warp_id = ir_builder_.CreateUDiv( @@ -1379,102 +1557,132 @@ Status IrEmitterUnnested::EmitRowReduction( llvm::Value* lane_id = ir_builder_.CreateURem( x_tile, ir_builder_.getInt64(kWarpSize), "lane_id"); - // The x-location of the last element in this tile. - // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize); + // The x-location of the last element in this z-x-tile. + // last_x = lane_id + warpSize * (x_tile_size - 1 + warp_id * + // x_tile_size); llvm::Value* last_x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - ir_builder_.getInt64(kTileSize - 1), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { - std::unique_ptr tile_element_loop = - llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", - ir_builder_.getInt64(0), - ir_builder_.getInt64(kTileSize), - ir_builder_.getInt64(1), &ir_builder_); - - // Emit the body of the partial reduction loop. - llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), - &ir_builder_); - // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize); - llvm::Value* x = ir_builder_.CreateNSWAdd( - lane_id, - ir_builder_.CreateNSWMul( - ir_builder_.getInt64(kWarpSize), - ir_builder_.CreateNSWAdd( - tile_element_loop->GetIndVarValue(), - ir_builder_.CreateNSWMul(warp_id, - ir_builder_.getInt64(kTileSize))))); - - // Unless we know the tile is entirely in bounds, we have to emit a - // x-in-bounds check before reading from the input. - if (!tile_in_bounds) { - llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse( - ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)), - "x_in_bounds", &ir_builder_); - - // Points ir_builder_ to the then-block. - llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, - &ir_builder_); - } + lane_id, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(x_tile_size - 1), + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + KernelSupportLibrary ksl( + &ir_builder_, + /*unroll_mode=*/xla::llvm_ir::UnrollMode::kFullyUnroll, + /*prevent_vectorization=*/false); + + // Emit a for-loop that partially reduces the elements in the given + // z-x-tile. + auto emit_z_x_tile_element_loop = [&](bool x_tile_in_bounds, + int64 x_tile_loop_bound) -> Status { + auto emit_z_tile_element_loop = [&](llvm::Value* z_indvar) -> Status { + llvm::Value* z = ir_builder_.CreateNSWAdd( + z_indvar, ir_builder_.CreateNSWMul( + ir_builder_.getInt64(z_tile_size), z_tile)); + + TF_RETURN_IF_ERROR(ksl.For( + "x_tile", + /*start=*/0, /*end=*/x_tile_loop_bound, /*step=*/1, + [&](llvm::Value* x_indvar) -> Status { + // x = lane_id + warpSize * (element_id_in_x_tile + warp_id * + // x_tile_size); + llvm::Value* x = ir_builder_.CreateNSWAdd( + lane_id, + ir_builder_.CreateNSWMul( + ir_builder_.getInt64(kWarpSize), + ir_builder_.CreateNSWAdd( + x_indvar, + ir_builder_.CreateNSWMul( + warp_id, ir_builder_.getInt64(x_tile_size))))); + + // Unless we know the x-tile is entirely in bounds, we have to + // emit a x-in-bounds check before reading from the input. + if (!x_tile_in_bounds) { + llvm_ir::LlvmIfData if_x_in_bounds_data = + llvm_ir::EmitIfThenElse(ir_builder_.CreateICmpULT( + x, ir_builder_.getInt64(width)), + "x_in_bounds", &ir_builder_); + // Points ir_builder_ to the then-block. + llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block, + &ir_builder_); + } + + // Emit code that reads the input element and accumulates it + // to the partial reduction result. + llvm::Value* input_address = + ir_builder_.CreateAlloca(element_ir_type); + { + // {z,y,x} is an index to input_3d_tensor_shape + // [depth,height,width]. We need to convert that to an index + // to input_shape (the shape of the operand of "reduce"). + // This conversion is composed of a transposition from + // input_shape to normalized_input_shape and a reshape from + // normalized_input_shape to input_3d_tensor_shape. + const Shape normalized_input_shape = ShapeUtil:: + MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = + LayoutUtil::MinorToMajor(input_shape); + const std::vector transpose_dimension_mapping( + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); + const Shape input_3d_tensor_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {depth, height, width}); + const llvm_ir::IrArray::Index input_3d_tensor_index( + {z, y, x}, input_3d_tensor_shape, &ir_builder_); + const llvm_ir::IrArray::Index input_index = + input_3d_tensor_index + .SourceIndexOfReshape(input_3d_tensor_shape, + normalized_input_shape, + &ir_builder_) + .SourceIndexOfTranspose( + normalized_input_shape, input_shape, + transpose_dimension_mapping, &ir_builder_); + + for (int i = 0; i != num_reduces; ++i) { + TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, + input_gens[i](input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], input_address}, + partial_reduction_result_addresses[i])); + } + return EmitExtraOutputsForReduce(reduce, input_index, + extra_output_gens); + } + })); + return Status::OK(); + }; - // Emit code that reads the input element and accumulates it to the - // partial reduction result. - llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); - { - // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We - // need to convert that to an index to input_shape (the shape of the - // operand of "reduce"). This conversion is composed of a transposition - // from input_shape to normalized_input_shape and a reshape from - // normalized_input_shape to input_3d_tensor_shape. - const Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); - const std::vector transpose_dimension_mapping( - input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); - const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), - {depth, height, width}); - const llvm_ir::IrArray::Index input_3d_tensor_index( - {z, y, x}, input_3d_tensor_shape, &ir_builder_); - const llvm_ir::IrArray::Index input_index = - input_3d_tensor_index - .SourceIndexOfReshape(input_3d_tensor_shape, - normalized_input_shape, &ir_builder_) - .SourceIndexOfTranspose(normalized_input_shape, input_shape, - transpose_dimension_mapping, - &ir_builder_); - TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, - input_gen(input_index)); - ir_builder_.CreateStore(input_ir_value, input_address); - } - return EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, input_address}, - partial_reduction_result_address); + return ksl.For("z_tile", + /*start=*/0, /*end=*/z_tile_size, /*step=*/1, + emit_z_tile_element_loop); }; llvm::Value* tile_in_bounds = ir_builder_.CreateOr( - ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0), + ir_builder_.getInt1(width % (x_tile_size * kWarpSize) == 0), ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width))); - llvm_ir::LlvmIfData if_tile_in_bounds_data = - llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, - &ir_builder_); - TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); - // After the if-then-else statement on tile_in_bounds, emit calls to - // shfl_down that accumulate the partial reduction results of all threads - // from the warp. - llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, - &ir_builder_); + TF_RETURN_IF_ERROR( + ksl.If(tile_in_bounds, + /*true_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop(/*x_tile_in_bounds=*/true, + x_tile_size); + }, + /*false_block_generator=*/ + [&]() -> Status { + return emit_z_x_tile_element_loop( + /*x_tile_in_bounds=*/false, + CeilOfRatio(width % (x_tile_size * kWarpSize), kWarpSize)); + })); + + // After accumulating the elements of the z_x_tile, emit calls to + // shfl_down that accumulate the partial reduction results of all + // threads in a warp. int bit_width = llvm_ir::GetSizeInBits(element_ir_type); // bitcast cannot be applied to aggregate types (even packed ones), so we // instead bitcast addresses of load/store to intN* of the same bit-width. @@ -1483,20 +1691,24 @@ Status IrEmitterUnnested::EmitRowReduction( : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { - llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(partial_reduction_result_address, - shuffle_ir_type->getPointerTo()), - "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); - ir_builder_.CreateStore( - EmitShuffleDown(partial_reduction_result, - ir_builder_.getInt32(shuffle_distance), &ir_builder_), - ir_builder_.CreateBitCast(result_from_other_lane, - shuffle_ir_type->getPointerTo())); - TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - *reducer, {partial_reduction_result_address, result_from_other_lane}, - partial_reduction_result_address)); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_addresses[i], + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), + &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {partial_reduction_result_addresses[i], result_from_other_lane}, + partial_reduction_result_addresses[i])); + } } const HloInstruction* output = @@ -1510,19 +1722,34 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = - GetIrArray(*output, *output) - .EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), - &ir_builder_, "output_element_address"); - return EmitAtomicOperationForNestedComputation( - *reducer, output_address, partial_reduction_result_address); + for (int i = 0; i != num_reduces; ++i) { + llvm::Value* output_address = + GetIrArray(*output, *output, reduce_output_shapes[i]) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index( + y, + ShapeUtil::GetSubshape(output->shape(), + reduce_output_shapes[i]), + &ir_builder_), + &ir_builder_, "output_element_address"); + if (x_tile_size * z_tile_size < depth * width) { + TF_RETURN_IF_ERROR(EmitAtomicOperationForNestedComputation( + *reducers[i], output_address, + partial_reduction_result_addresses[i])); + } else { + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducers[i], + {output_address, partial_reduction_result_addresses[i]}, + output_address)); + } + } + return Status::OK(); }; // Emit a parallel loop that iterates through every input tiles. Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( - reduce->shape().element_type(), {depth, height, width_in_tiles}, - {2, 1, 0}); + reduce->shape().element_type(), + {depth / z_tile_size, height, width_in_tiles}, {2, 1, 0}); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( tiled_input_shape, ir_emitter_context_->device_description()); CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); @@ -1543,10 +1770,14 @@ Status IrEmitterUnnested::EmitRowReduction( // elementwise. Status IrEmitterUnnested::EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer) { + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens) { // This emission requires "reduce" to have an input layout. It is either set // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for // a fused kReduce). @@ -1581,8 +1812,9 @@ Status IrEmitterUnnested::EmitReductionToVector( // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. if (input_dims_to_keep.empty()) { - return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen, - reducer); + return EmitReductionToScalar(reduce, input_shape, input_gens, + init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } else if (input_dims_to_keep.front() == LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. Treat the result of "input" as a matrix whose width @@ -1599,8 +1831,9 @@ Status IrEmitterUnnested::EmitReductionToVector( height *= input_shape.dimensions(input_dim); } } - return EmitColumnReduction(height, width, reduce, input_shape, input_gen, - init_value_gen, reducer); + return EmitColumnReduction(height, width, reduce, input_shape, input_gens, + init_value_gens, reducers, reduce_output_shapes, + extra_output_gens); } else { // Reduce the row dimension of a matrix or reduce dimension 0 and 2 in a // 3D tensor. The size of dimension 1 (the height) is the size of the @@ -1626,7 +1859,8 @@ Status IrEmitterUnnested::EmitReductionToVector( } const int64 height = ShapeUtil::ElementsIn(reduce->shape()); return EmitRowReduction(depth, height, width, reduce, input_shape, - input_gen, init_value_gen, reducer); + input_gens, init_value_gens, reducers, + reduce_output_shapes, extra_output_gens); } } @@ -1650,16 +1884,15 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( - reduce, input->shape(), - [&](const llvm_ir::IrArray::Index& index) { + reduce, input->shape(), {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*input, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - [&](const llvm_ir::IrArray::Index& index) { + }}, + {[&](const llvm_ir::IrArray::Index& index) { return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); - }, - dimensions_to_reduce, reducer); + }}, + dimensions_to_reduce, {reducer}, {{}}, {}); } thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); @@ -2281,7 +2514,9 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( if (alpha->opcode() == HloOpcode::kBroadcast) { alpha = alpha->operand(0); } - alpha = inst->operand(alpha->parameter_number()); + if (alpha->opcode() == HloOpcode::kParameter) { + alpha = inst->operand(alpha->parameter_number()); + } // TODO(b/74185543): Remove the following if block once we support fusion // with a non-constant as well. Then we will just always use the constant // on the device. @@ -2324,21 +2559,30 @@ std::unique_ptr IrEmitterUnnested::BuildFftThunk( } StatusOr> IrEmitterUnnested::BuildInitializerThunk( - const HloInstruction* hlo) { + const HloInstruction* hlo, const ShapeIndex& index) { bool fused = HloOpcode::kFusion == hlo->opcode(); const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo; - const HloInstruction* init_value = [&] { + const HloInstruction* init_value_operand = [&] { switch (inst->opcode()) { case HloOpcode::kSelectAndScatter: return inst->operand(2); case HloOpcode::kReduce: return inst->operand(1); + case HloOpcode::kTuple: + CHECK(hlo->IsMultiOutputFusion()) + << ": " << hlo->ToString() << " is not a multi-output fusion."; + CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce) + << ": Found '" << inst->operand(index.back())->opcode() << "' in " + << inst->ToString() << " but expected 'reduce'."; + // For multi-output fusion look through the tuple. + return inst->operand(index.back())->operand(1); default: LOG(FATAL) << "Opcode " << inst->opcode() << " should not need an initializer."; } }(); + const HloInstruction* init_value = init_value_operand; if (fused && init_value->opcode() == HloOpcode::kParameter) { init_value = hlo->operand(init_value->parameter_number()); } @@ -2356,7 +2600,7 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( ArraySlice literal_bytes( reinterpret_cast(literal.untyped_data()), num_bytes); if (c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) { - return {MakeUnique(GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique(GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by @@ -2372,8 +2616,8 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( pattern16 = literal_bytes.front(); } uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16); - return {MakeUnique(pattern32, - GetAllocationSlice(*hlo), hlo)}; + return {MakeUnique( + pattern32, GetAllocationSlice(*hlo, index), hlo)}; } // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit @@ -2383,20 +2627,31 @@ StatusOr> IrEmitterUnnested::BuildInitializerThunk( literal_bytes.size() - 4) == 0) { uint32 word; memcpy(&word, literal_bytes.data(), sizeof(word)); - return {MakeUnique(word, GetAllocationSlice(*hlo), - hlo)}; + return {MakeUnique( + word, GetAllocationSlice(*hlo, index), hlo)}; } } // Otherwise fall back to our slow initializer code. std::unique_ptr kernel_thunk = BuildKernelThunk(hlo); - TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk( - *hlo, - [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value, *hlo) - .EmitReadArrayElement(index, &ir_builder_); - }, - kernel_thunk.get())); + LaunchDimensions launch_dimensions = + CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index), + ir_emitter_context_->device_description()); + UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(), + ir_emitter_context_->llvm_module()); + // If the init_value was fused into this reduce we have to generate it first. + if (fused && init_value_operand->opcode() != HloOpcode::kParameter) { + CHECK_EQ(HloOpcode::kConstant, init_value_operand->opcode()); + TF_RETURN_IF_ERROR(HandleConstant(const_cast(init_value))); + } + TF_RETURN_IF_ERROR(ParallelLoopEmitter( + [=](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *hlo) + .EmitReadArrayElement(index, &ir_builder_); + }, + GetIrArray(*hlo, *hlo, index), launch_dimensions, + &ir_builder_) + .EmitLoop(IrName(hlo))); // Clean up state left behind by emitting the loop above. (This is normally // done in IrEmitterUnnested::Postprocess().) diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index e42c5e86862576bad1c8610652d1c50d2364cd83..202231b82f3877c11cf932bd00a8aac350fd0afa 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -100,6 +100,13 @@ class IrEmitterUnnested : public IrEmitter { const HloInstruction& inst, tensorflow::gtl::ArraySlice args); + // Helper for writing extra outputs from inside a reduce kernel. + Status EmitExtraOutputsForReduce( + const HloInstruction* reduce, const llvm_ir::IrArray::Index& index, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); + // EmitColumnReduction and EmitRowReduction emit code for column and row // reduction of a matrix and/or 3D tensor. Row and column reduction have // different memory access pattern, so for performance their implementations @@ -110,28 +117,43 @@ class IrEmitterUnnested : public IrEmitter { // `EmitReductionToVector`. Note that input shape might not be // [height x width], but can be bitcast to [height x weight] with "height" // being the major dimension. - Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitColumnReduction( + int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a 3D tensor of shape [depth x height x width] to a // vector of shape [height]. Other parameters have the same meaning as those // of `EmitReductionToVector`. Note that input shape might not be // [depth x height x width], but can be bitcast to [depth x height x weight] // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitRowReduction( + int64 depth, int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Emits code that reduces a tensor of arbitrary rank to a scalar. - Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); + Status EmitReductionToScalar( + HloInstruction* reduce, const Shape& input_shape, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Figures out whether `reduce` is a row or column reduction, and which // dimensions to reduce, and calls either `EmitRowReduction` or @@ -141,13 +163,24 @@ class IrEmitterUnnested : public IrEmitter { // generate elements of the input and the initial value. Other parameters mean // the same as for `HandleReduce`. // + // Multiple reduces can be emitted in the same loop, assuming they have the + // same input and output shapes, and the same reduce dimensions. + // + // extra_output_gens can contain extra generators for intermediate outputs. + // These must have the same shape as the reduce input as they are computed + // when the reduce inputs are being read. + // // Prerequisite: `IsReductionToVector(*reduce)` Status EmitReductionToVector( HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice input_gens, + tensorflow::gtl::ArraySlice init_value_gens, tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); + tensorflow::gtl::ArraySlice reducers, + tensorflow::gtl::ArraySlice reduce_output_shapes, + tensorflow::gtl::ArraySlice< + std::pair> + extra_output_gens); // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned @@ -166,7 +199,7 @@ class IrEmitterUnnested : public IrEmitter { // Returns a thunk that, given a reduce or select-and-scatter op, initializes // its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( - const HloInstruction* hlo); + const HloInstruction* hlo, const ShapeIndex& index = {}); // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..e3f444a1268cf8e3e551a3c5b986fee0339fa327 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc @@ -0,0 +1,144 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace gpu { + +GpuMultiOutputFusion::GpuMultiOutputFusion() : MultiOutputFusion(INT64_MAX) {} + +bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) { + auto get_element_instr = + [&](const HloInstruction* instr) -> const HloInstruction* { + const HloInstruction* element_instr = instr; + if (instr->opcode() == HloOpcode::kFusion) { + auto fused_expression_root = instr->fused_expression_root(); + if (instr->IsMultiOutputFusion()) { + // If possible, we want to pick a reduce operand of the fusion root, + // because it has the most constraints. + for (const auto* inst : fused_expression_root->operands()) { + if (inst->opcode() == HloOpcode::kReduce) { + return inst; + } + } + return fused_expression_root->operands()[0]; + } else { + element_instr = fused_expression_root; + } + } + return element_instr; + }; + + auto get_element_shape = [&](const HloInstruction* element_instr) { + // Special handling of kReduce instructions -- the fusion + // applies to the first operand. + if (element_instr->opcode() == HloOpcode::kReduce) { + return element_instr->operand(0)->shape(); + } + return element_instr->shape(); + }; + + // The shapes in all tuple operands should agree, unless it is a reduce. + // In that case, the operand of the reduce needs to have the same shape + // as the other tuple operands, but also we need to compare the output + // shapes of the reduces. + auto* element_instr_1 = get_element_instr(instr1); + auto* element_instr_2 = get_element_instr(instr2); + if (element_instr_1->opcode() == HloOpcode::kReduce && + element_instr_2->opcode() == HloOpcode::kReduce && + !ShapeUtil::Equal(element_instr_1->shape(), element_instr_2->shape())) { + return false; + } + // The elementwise output shapes must be the same (including layout). + return ShapeUtil::Equal(get_element_shape(element_instr_1), + get_element_shape(element_instr_2)); +} + +bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) { + // kConstant instruction will not have memory reads, so it won't be a profit + // source. Skip them. + if (instr->opcode() == HloOpcode::kConstant && + ShapeUtil::IsEffectiveScalar(instr->shape())) { + return false; + } + // We don't target to fuse producer/consumer instructions -- this should + // be taken care of by the instruction_fusion pass. If instr has only + // one user, it will not have sibling instructions. We won't consider it. + if (instr->user_count() < 2) { + return false; + } + return true; +} + +namespace { +bool IsReduction(HloInstruction* instr) { + if (instr->IsMultiOutputFusion()) { + for (const HloInstruction* operand : + instr->fused_expression_root()->operands()) { + if (operand->opcode() == HloOpcode::kReduce) { + return true; + } + } + return false; + } else if (instr->opcode() == HloOpcode::kFusion) { + return instr->fused_expression_root()->opcode() == HloOpcode::kReduce; + } else { + return instr->opcode() == HloOpcode::kReduce; + } +} +} // namespace + +bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) { + return IsReduction(instr); +} + +int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1, + HloInstruction* instr2) { + tensorflow::gtl::FlatSet in_list; + for (auto instr : instr1->operands()) { + if (!IsProfitableOperand(instr)) { + continue; + } + in_list.insert(instr); + } + int64 profit = 0; + for (auto instr : instr2->operands()) { + if (!IsProfitableOperand(instr) || in_list.count(instr) == 0) { + continue; + } + profit += ShapeUtil::ByteSizeOf(instr->shape()); + } + VLOG(2) << "Fusing instr1=" << instr1->name() << " instr2=" << instr2->name() + << ", the profit is =" << profit; + return profit; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..5451a93cec5e4aeca05717c181edb9dad0305c83 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +namespace xla { +namespace gpu { + +// Multi-output fusion of sibling and producer-consumer instructions for the +// Jellyfish backend. +class GpuMultiOutputFusion : public MultiOutputFusion { + public: + GpuMultiOutputFusion(); + + protected: + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) override; + + // We currently only consider reduce and reduce fusion nodes as candidates. + bool IsFusible(HloInstruction* instr) override; + + // This function estimates the amount of memory reads saved by merging + // instr1 and instr2 into one multi-output fusion instruction. For a fusion + // instruction, all the operands need to be loaded from memory. If we merge + // instr1 and instr2, common operands will not be loaded twice. The profit is + // estimated as the size of the common operands b/w instr1 and instr2. + int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) override; + + // Whether fusing the instruction can reduce memory reads. + // + // TODO(tjoerg): Move this method up into the MultiOutputFusion base class. + bool IsProfitableOperand(HloInstruction* instr) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..924cfb11f3da76c475458ea14f201e809be61be8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion_test.cc @@ -0,0 +1,230 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace gpu { + +using InstructionFusionTest = HloTestBase; + +const char kModulePrefix[] = R"( + HloModule test_module + + scalar_add_computation { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + ROOT add.0 = f32[] add(scalar_lhs.0, scalar_rhs.0) + } + scalar_mul_computation { + scalar_lhs.1 = f32[] parameter(0) + scalar_rhs.1 = f32[] parameter(1) + ROOT mul.1 = f32[] add(scalar_lhs.1, scalar_rhs.1) + })"; + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) { + // Fusion with reduce instruction root and a sibling reduce instruction + // sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] constant(1) + fusion = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation + reduce.2 = f32[512]{0} reduce(p1, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion, reduce.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[6400]{0} parameter(1) + mul = f32[6400]{0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[6400]{0} parameter(1) + r1 = f32[64,100]{0,1} reshape(p1.2) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[6400]{0} parameter(1) + fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceOutputShapes) { + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[] reduce(mul, const.1), dimensions={0,1}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10]{0} parameter(0) + ROOT reduce.2 = f32[10]{0} reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1.3 = f32[10,10]{1,0} parameter(1) + fusion.1 = f32[] fusion(p0, p1.3), kind=kInput, calls=fused_computation_1 + p2 = f32[] parameter(2) + fusion.2 = f32[10]{0} fusion(p2, p1.3), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[], f32[10]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) { + // Two sibling fusions with reduce instruction roots sharing the same input + // param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + mul = f32[128,512,28,28]{3,2,1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + ROOT reduce.1 = f32[512]{0} reduce(mul, const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + fused_computation_2 { + p1.2 = f32[128,512,28,28]{3,2,1,0} parameter(1) + const.2 = f32[] parameter(0) + ROOT reduce.2 = f32[512]{0} reduce(p1.2, const.2), dimensions={0,2,3}, to_apply=scalar_add_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[128,512,28,28]{3,2,1,0} parameter(1) + fusion.1 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_1 + fusion.2 = f32[512] fusion(p0, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[512]{0}, f32[512]{0}) tuple(fusion.1, fusion.2) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingReduceAndReduceMultiOutputFusion) { + // Multi-output fusion with two reduce instructions root and a sibling reduce + // instruction sharing the same input param. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation (p0: f32[128,512,28,28]) -> (f32[512], f32[512]) { + const.1 = f32[] constant(1) + p0.1 = f32[128,512,28,28]{3,2,1,0} parameter(0) + mul = f32[128,512,28,28]{3,2,1,0} multiply(f32[128,512,28,28]{3,2,1,0} p0.1, f32[128,512,28,28]{3,2,1,0} p0.1) + reduce.1 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} mul, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + reduce.2 = f32[512]{0} reduce(f32[128,512,28,28]{3,2,1,0} p0.1, f32[] const.1), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT tuple = (f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} reduce.1, f32[512]{0} reduce.2) + } + + ENTRY entry (p0: f32[128,512,28,28]) -> (f32[512], f32[512], f32[512]) { + p0 = f32[128,512,28,28]{3,2,1,0} parameter(0) + const = f32[] constant(1) + fusion = (f32[512]{0}, f32[512]{0}) fusion(f32[128,512,28,28]{3,2,1,0} p0), kind=kInput, calls=fused_computation + get-tuple-element = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=0 + get-tuple-element.1 = f32[512]{0} get-tuple-element((f32[512]{0}, f32[512]{0}) fusion), index=1 + reduce.3 = f32[512]{0} reduce(p0, const), dimensions={0,2,3}, to_apply=scalar_add_computation + ROOT root = (f32[512]{0}, f32[512]{0}, f32[512]{0}) tuple(f32[512]{0} get-tuple-element, f32[512]{0} get-tuple-element.1, f32[512]{0} reduce.3) + })")) + .ValueOrDie(); + ASSERT_TRUE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* fusion = + module->entry_computation()->root_instruction()->operand(0)->operand(0); + ASSERT_TRUE(fusion->IsMultiOutputFusion()); + EXPECT_THAT(fusion->fused_expression_root(), + op::Tuple(op::Reduce(), op::Reduce(), op::Reduce())); +} + +TEST_F(InstructionFusionTest, + MultiOutputFusionSiblingFusionCheckAgainstReduceOperand) { + // Verify that if we already have a multi-output fusion that we prefer to pick + // a reduce op from its operands for checking shape compatibility. + auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"( + fused_computation_1 { + p1.1 = f32[10,10]{1,0} parameter(1) + mul = f32[10,10]{1,0} multiply(p1.1, p1.1) + const.1 = f32[] parameter(0) + reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0,1}, to_apply=scalar_add_computation + ROOT tuple = (f32[10,10], f32[]) tuple(mul, reduce.1) + } + + fused_computation_2 { + p1.2 = f32[10,10]{1,0} parameter(1) + const.2 = f32[10] parameter(0) + ROOT reduce.2 = f32[10] reduce(p1.2, const.2), dimensions={0}, to_apply=scalar_mul_computation + } + + ENTRY entry { + p0 = f32[] parameter(0) + p1 = f32[10,10]{1,0} parameter(1) + p2 = f32[10]{0} parameter(2) + fusion.1 = (f32[10,10], f32[10]) fusion(p0, p1), kind=kInput, calls=fused_computation_1 + get-tuple-element.1 = f32[10,10] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=0 + get-tuple-element.2 = f32[] get-tuple-element((f32[10,10], f32[10]) fusion.1), index=1 + fusion.2 = f32[10] fusion(p2, p1), kind=kInput, calls=fused_computation_2 + ROOT root = (f32[10,10], f32[], f32[10]) tuple(get-tuple-element.1, get-tuple-element.2, fusion.2) + })")) + .ValueOrDie(); + ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc index 696fa7e0194032b5c78bf11383c3280a62de07fa..6f4bb0580e8dfc1dce1cca0a60cc3dd9ea600fb3 100644 --- a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc @@ -33,8 +33,7 @@ class StreamAssignmentTest : public HloTestBase { auto debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_disable_multi_streaming(false); config.set_debug_options(debug_options); - return MakeUnique("test_module", VersionedComputationHandle(), - config); + return MakeUnique("test_module", config); } // Pre-canned shapes. diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..a50ddf6ac63c7fa7ccace94bc7f40f438aedccf8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.cc @@ -0,0 +1,151 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" + +#include "tensorflow/compiler/xla/layout_util.h" + +namespace xla { +namespace gpu { + +using stream_executor::dnn::DataLayout; +using stream_executor::dnn::DataLayoutString; +using stream_executor::dnn::FilterLayout; +using stream_executor::dnn::FilterLayoutString; + +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + DataLayout input, FilterLayout filter, + DataLayout output) { + std::vector input_layout; + switch (input) { + case DataLayout::kBatchDepthYX: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.push_back(dnums.input_feature_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + input_layout.push_back(dnums.input_batch_dimension()); + input_layout.insert(input_layout.end(), + dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end()); + input_layout.push_back(dnums.input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid input layout: ", + DataLayoutString(input)); + } + + std::vector filter_layout; + switch (filter) { + case FilterLayout::kOutputInputYX: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + break; + case FilterLayout::kOutputYXInput: + filter_layout.push_back(dnums.kernel_output_feature_dimension()); + filter_layout.insert(filter_layout.end(), + dnums.kernel_spatial_dimensions().begin(), + dnums.kernel_spatial_dimensions().end()); + filter_layout.push_back(dnums.kernel_input_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid filter layout: ", + FilterLayoutString(filter)); + } + + std::vector output_layout; + switch (output) { + case DataLayout::kBatchDepthYX: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.push_back(dnums.output_feature_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + break; + case DataLayout::kBatchYXDepth: + output_layout.push_back(dnums.output_batch_dimension()); + output_layout.insert(output_layout.end(), + dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end()); + output_layout.push_back(dnums.output_feature_dimension()); + break; + default: + return tensorflow::errors::Internal("Invalid output layout: ", + DataLayoutString(output)); + } + + return std::make_tuple(LayoutUtil::MakeLayoutFromMajorToMinor(input_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(filter_layout), + LayoutUtil::MakeLayoutFromMajorToMinor(output_layout)); +} + +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output) { + Layout nchw_input, nchw_filter, nchw_output; + std::tie(nchw_input, nchw_filter, nchw_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchDepthYX, + FilterLayout::kOutputInputYX, + DataLayout::kBatchDepthYX) + .ConsumeValueOrDie(); + + Layout nhwc_input, nhwc_filter, nhwc_output; + std::tie(nhwc_input, nhwc_filter, nhwc_output) = + StreamExecutorConvLayoutsToXlaLayouts(dnums, DataLayout::kBatchYXDepth, + FilterLayout::kOutputYXInput, + DataLayout::kBatchYXDepth) + .ConsumeValueOrDie(); + + DataLayout input_layout; + if (LayoutUtil::Equal(input, nchw_input)) { + input_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(input, nhwc_input)) { + input_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid input layout: ", + input.ShortDebugString()); + } + + FilterLayout filter_layout; + if (LayoutUtil::Equal(filter, nchw_filter)) { + filter_layout = FilterLayout::kOutputInputYX; + } else if (LayoutUtil::Equal(filter, nhwc_filter)) { + filter_layout = FilterLayout::kOutputYXInput; + } else { + return tensorflow::errors::Internal("Invalid filter layout: ", + filter.ShortDebugString()); + } + + DataLayout output_layout; + if (LayoutUtil::Equal(output, nchw_output)) { + output_layout = DataLayout::kBatchDepthYX; + } else if (LayoutUtil::Equal(output, nhwc_output)) { + output_layout = DataLayout::kBatchYXDepth; + } else { + return tensorflow::errors::Internal("Invalid output layout: ", + output.ShortDebugString()); + } + + return std::make_tuple(input_layout, filter_layout, output_layout); +} +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/stream_executor_util.h b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h new file mode 100644 index 0000000000000000000000000000000000000000..8218f4fd11d3978d0ecc53fc15e287aea4b69ec3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/stream_executor_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ + +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +// Helper functions for interacting with StreamExecutor. + +namespace xla { +namespace gpu { + +// Returns (input, filter, output) XLA Layout protos given the StreamExecutor +// layouts. +StatusOr> +StreamExecutorConvLayoutsToXlaLayouts(const ConvolutionDimensionNumbers& dnums, + stream_executor::dnn::DataLayout input, + stream_executor::dnn::FilterLayout filter, + stream_executor::dnn::DataLayout output); + +// Returns (input, filter, output) StreamExecutor layouts given the XLA layouts. +StatusOr> +XlaConvLayoutsToStreamExecutorLayouts(const ConvolutionDimensionNumbers& dnums, + const Layout& input, const Layout& filter, + const Layout& output); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_EXECUTOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ad55728c45599c801aad7e12fac95ae9f0c4fc3b..7749201cbceece216a2db2569936949eb7de5125 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -457,8 +457,8 @@ class WhileBodyComputationMatcher : public MatcherBase { return InvalidArgument("Unexpected tuple index instruction : %s", inst->name().c_str()); } else if (tag == "loop_increment") { - // Parse the constant which represents the loop induction variable - // increment value. + // ParseHloString the constant which represents the loop induction + // variable increment value. TF_RETURN_IF_ERROR(ParseConstInteger(inst, &loop_increment_)); } else if (tag == "param0" && inst != computation_->parameter_instruction(0)) { diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 06a5e0351b63270b61b998ca2211f480f256f759..5dba50a63b1d77bda0835e0333cc7dd9ddbe2dcf 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -26,6 +26,41 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; +StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 8b2b43a37a5c41d334e5338c6a6fad160f03a51e..3be3bb8e7fae9d22c9be00f3f81d2b93638c22ef 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -34,6 +34,21 @@ limitations under the License. namespace xla { +// Returns the minimum memory required to compute an HLO module where all +// computations have been scheduled (represented by the given module_sequence), +// assuming no fragmentation. +StatusOr MinimumMemoryForModule( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + +// Returns the minimum memory required to compute the given computation, +// assuming no fragmentation. +StatusOr MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function); + // Forward declare classes defined below. class HeapAlgorithm; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 6271652412c2979ff926702f12722102344b0dfb..309ab85f784274835904015472f3f0d601885763 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -34,6 +34,64 @@ limitations under the License. namespace xla { namespace { +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, MinimumMemoryForModule(module_sequence, size_fn).ValueOrDie()); +} + const char kAlloc[] = "Alloc"; const char kFree[] = "Free"; const char kFinish[] = "Finish"; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 1f7c1cffd324ad2f4e4cdf11046c8459b8ceb6d5..e201359d3d25b7d2dda852762c6de1fcb75685d7 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -145,6 +145,7 @@ message HloInstructionProto { repeated int64 operand_ids = 36; repeated int64 control_predecessor_ids = 37; repeated int64 called_computation_ids = 38; + repeated int64 replica_group_ids = 44; xla.OpSharding sharding = 40; diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils.h b/tensorflow/compiler/xla/service/hlo_casting_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..7f73bba036534a62a70a80431236cffa766c9b38 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils.h @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Casting utilitiy functions for HLO instructions. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ + +#include +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +class HloInstruction; + +template +using EnableIfDerivedFromHlo = + typename std::enable_if::value>::type; + +// TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like +// RTTI if it turns out to be a performance issue. +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr or runtime information does not match. +// +// Similar to LLVM's cast. +template * = nullptr> +const T* Cast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + const T* casted = dynamic_cast(instruction); + CHECK(casted != nullptr); + return casted; +} + +// Non-const overload of Cast. +template * = nullptr> +T* Cast(HloInstruction* instruction) { + return const_cast( + Cast(const_cast(instruction))); +} + +// Works just like the Cast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's cast_or_null. +template * = nullptr> +const T* CastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? Cast(instruction) : nullptr; +} + +// Non-const overload of CastOrNull. +template * = nullptr> +T* CastOrNull(HloInstruction* instruction) { + return const_cast( + CastOrNull(const_cast(instruction))); +} + +// Casts an HloInstruction pointer to one of its subclasses, dies if argument is +// nullptr, returns nullptr if runtime information does not match. +// +// Similar to LLVM's dyn_cast. +template * = nullptr> +const T* DynCast(const HloInstruction* instruction) { + CHECK(instruction != nullptr); + return dynamic_cast(instruction); +} + +// Non-const overload of DynCast. +template * = nullptr> +T* DynCast(HloInstruction* instruction) { + return const_cast( + DynCast(const_cast(instruction))); +} + +// Works just like the DynCast, except that it allows for a null pointer as an +// argument which it then propagates. +// +// Similar to LLVM's dyn_cast_or_null. +template * = nullptr> +const T* DynCastOrNull(const HloInstruction* instruction) { + return instruction != nullptr ? DynCast(instruction) : nullptr; +} + +// Non-const overload of DynCastOrNull. +template * = nullptr> +T* DynCastOrNull(HloInstruction* instruction) { + return const_cast( + DynCastOrNull(const_cast(instruction))); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3364275409122254bf99b40a7d2fcbb2d7564cc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_casting_utils_test.cc @@ -0,0 +1,113 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class DummyInstruction : public HloInstruction { + public: + DummyInstruction() + : HloInstruction(HloOpcode::kConstant, ShapeUtil::MakeShape(F32, {})) {} +}; + +class AnotherDummyInstruction : public HloInstruction { + public: + AnotherDummyInstruction() + : HloInstruction(HloOpcode::kParameter, ShapeUtil::MakeShape(F32, {})) {} +}; + +TEST(HloCastingUtilsTest, CastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(Cast(null), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + Cast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, CastOrNullDiesForWrongType) { + AnotherDummyInstruction instruction; + ASSERT_DEATH( + Cast(static_cast(&instruction)), ""); +} + +TEST(HloCastingUtilsTest, CastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = CastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = + DynCast(static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastDiesForNullptr) { + HloInstruction* null = nullptr; + ASSERT_DEATH(DynCast(null), ""); +} + +TEST(HloCastingUtilsTest, DynCastOrNullSucceeds) { + DummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, &instruction); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForWrongType) { + AnotherDummyInstruction instruction; + DummyInstruction* casted = DynCastOrNull( + static_cast(&instruction)); + ASSERT_EQ(casted, nullptr); +} + +TEST(HloCastingUtilsTest, DynCastOrNullReturnsNullptrForNullptr) { + HloInstruction* null = nullptr; + DummyInstruction* casted = DynCastOrNull(null); + ASSERT_EQ(casted, nullptr); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h new file mode 100644 index 0000000000000000000000000000000000000000..658643b427a9625fac1166151a89cbd669f817d5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_clone_context.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { + +class HloInstruction; +class HloComputation; +class HloModule; + +// Data structure used to track the cloning of HloInstruction and HloComputation +// objects. +class HloCloneContext { + public: + // Creates a new HloCloneContext object to clone HloInstruction and + // HloComputation objects to be added to the module specified as argument. + // The suffix string will be appended to computation names. + explicit HloCloneContext(HloModule* module, const string& suffix = "") + : module_(module), suffix_(suffix) {} + + HloModule* module() const { return module_; } + + const string& suffix() const { return suffix_; } + + void MapInstruction(const HloInstruction* old_instruction, + HloInstruction* new_instruction) { + instructions_[old_instruction] = new_instruction; + } + + void MapComputation(const HloComputation* old_computation, + HloComputation* new_computation) { + computations_[old_computation] = new_computation; + } + + // Finds the new instruction mapped to its old copy, or return nullptr in case + // it is not found. + HloInstruction* FindInstruction(const HloInstruction* old_instruction) const { + return FindOrDefault(instructions_, old_instruction, nullptr); + } + + // Finds the new computation mapped to its old copy, or return nullptr in case + // it is not found. + HloComputation* FindComputation(const HloComputation* old_computation) const { + return FindOrDefault(computations_, old_computation, nullptr); + } + + // Retrieves the new instruction mapped to its old copy, or fail if not found. + HloInstruction* GetInstruction(const HloInstruction* old_instruction) const { + return FindOrDie(instructions_, old_instruction); + } + + // Retrieves the new computation mapped to its old copy, or fail if not found. + HloComputation* GetComputation(const HloComputation* old_computation) const { + return FindOrDie(computations_, old_computation); + } + + const tensorflow::gtl::FlatMap& + cloned_instructions() const { + return instructions_; + } + + const tensorflow::gtl::FlatMap& + cloned_computations() const { + return computations_; + } + + private: + HloModule* module_; + string suffix_; + tensorflow::gtl::FlatMap + instructions_; + tensorflow::gtl::FlatMap + computations_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CLONE_CONTEXT_H_ diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 63c3dc4a5932f754a9ccdd70d03c999fe528a448..b158f44923982642615543cebbd54f00596e2641 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -64,7 +64,7 @@ HloComputation::HloComputation( const string& name, int parameter_count, std::vector>* instructions, HloInstruction* root_instruction, HloInstruction* fusion_instruction) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), unique_id_(-1), root_instruction_(root_instruction), fusion_instruction_(fusion_instruction) { @@ -315,6 +315,42 @@ void ComputeComputationPostOrder( } } +std::list ComputeInstructionPostOrder( + HloInstruction* root, tensorflow::gtl::FlatSet* visited) { + std::list post_order; + std::vector> dfs_stack; + dfs_stack.emplace_back(root, false); + while (!dfs_stack.empty()) { + const auto current = dfs_stack.back(); + if (current.second) { + dfs_stack.pop_back(); + if (!visited->insert(current.first).second) { + continue; + } + post_order.push_back(current.first); + } else { + if (visited->count(current.first)) { + dfs_stack.pop_back(); + continue; + } + dfs_stack.back().second = true; + + // Add the operands to the stack in reverse order so the first operand is + // processed first. This will produce a more natural ordering and a nicer + // result for thigns like HLO stringification. + const auto& operands = current.first->operands(); + for (int64 i = operands.size() - 1; i >= 0; --i) { + dfs_stack.emplace_back(operands[i], false); + } + + for (HloInstruction* op : current.first->control_predecessors()) { + dfs_stack.emplace_back(op, false); + } + } + } + return post_order; +} + } // namespace std::list HloComputation::MakeInstructionPostOrder() const { @@ -328,9 +364,9 @@ std::list HloComputation::MakeInstructionPostOrder() const { // users). trace_instructions.push_back(instruction.get()); } else if (instruction->users().empty()) { - post_order.splice(post_order.end(), - InstructionPostOrderer::GetOrder(instruction.get(), - &added_instructions)); + post_order.splice( + post_order.end(), + ComputeInstructionPostOrder(instruction.get(), &added_instructions)); } } post_order.splice(post_order.end(), trace_instructions); @@ -752,22 +788,21 @@ Status HloComputation::Accept( } std::unique_ptr HloComputation::Clone( - const string& suffix, HloModule* module, - HloInstruction::CloneMap* clone_map) { + const string& suffix, HloCloneContext* context) { return CloneWithReplacements( /*replacements=*/std::unordered_map>(), - module, clone_map, suffix); + context, suffix); } std::unique_ptr HloComputation::CloneWithReplacements( std::unordered_map> replacements, - HloModule* module, HloInstruction::CloneMap* clone_map, - const string& suffix) { - HloInstruction::CloneMap local_clone_map; - if (clone_map == nullptr) { - clone_map = &local_clone_map; + HloCloneContext* context, const string& suffix) { + std::unique_ptr context_ptr; + if (context == nullptr) { + context_ptr = MakeUnique(parent(), suffix); + context = context_ptr.get(); } // Look up instr in the replacements map, and return either the replacement, @@ -792,18 +827,18 @@ std::unique_ptr HloComputation::CloneWithReplacements( } std::vector> instructions; - std::unique_ptr new_instr = nullptr; + std::unique_ptr new_instr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { auto replaced_operand = replace(operand); CHECK_NE(replaced_operand, nullptr) - << "Replacements map specifies to leave out " << operand->ToString() - << ", but it is used by " << instr->ToString() << "."; - new_operands.push_back(FindOrDie(*clone_map, replaced_operand)); + << "replacements map tried to eliminate a used instruction " + << operand->ToString() << ", used by " << instr->ToString(); + new_operands.push_back(context->GetInstruction(replaced_operand)); } - new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands, - module, clone_map); + new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, context); instructions.push_back(std::move(new_instr)); } Builder builder(name() + "." + suffix); @@ -811,22 +846,23 @@ std::unique_ptr HloComputation::CloneWithReplacements( builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction()))); + /*root_instruction=*/context->GetInstruction( + replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { - HloInstruction* new_instr = FindOrDie(*clone_map, instr); + HloInstruction* new_instr = context->GetInstruction(instr); for (auto successor : instr->control_successors()) { auto replaced_successor = replace(successor); - CHECK_NE(replaced_successor, nullptr) - << "Replacements map specifies to leave out " << successor->ToString() - << ", but it is control-depended-on by " << instr->ToString() << "."; - - TF_CHECK_OK(new_instr->AddControlDependencyTo( - FindOrDie(*clone_map, replaced_successor))); + // successor may not have been remapped, because it might have been + // removed by the replacements map. + if (replaced_successor != nullptr) { + TF_CHECK_OK(new_instr->AddControlDependencyTo( + context->GetInstruction(replaced_successor))); + } } } - + context->MapComputation(this, result.get()); // We cloned the elements of 'replacements', so they're all going to be // destroyed. HloInstructions need to be detached from their operands before // they're destroyed, otherwise they stick around in the operands' users lists @@ -836,7 +872,6 @@ std::unique_ptr HloComputation::CloneWithReplacements( new_instr->DetachFromOperands(); } } - return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 8bc97df0365a32bdc89d4636ad4c7076ffb08296..0da4a305f3d5d694a1918fed294337100b0a27fd 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -300,17 +301,11 @@ class HloComputation { const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - // - // If the module pointer is not nullptr, then the cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the computation is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone( - const string& suffix = "clone", HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr); + // If the clone context is specified, it will be populated with the cloned + // object mappings, and its module() will be used to add new computations + // into. + std::unique_ptr Clone(const string& suffix = "clone", + HloCloneContext* context = nullptr); // Like Clone(), but if an instruction is present in replacement_map, we use // the map's value to replace that instruction in the cloned computation. @@ -320,9 +315,7 @@ class HloComputation { std::unique_ptr CloneWithReplacements( std::unordered_map> replacements, - HloModule* module = nullptr, - HloInstruction::CloneMap* clone_map = nullptr, - const string& suffix = "clone"); + HloCloneContext* context = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the computation. // Parameter instructions cannot be removed without violating invariants of diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 94c9c7eabcc99d4cf61f535925c068a9b55ed136..762e1afc71b108b2e32b5a7f7f1bbeb783fc6fbd 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -172,15 +172,22 @@ Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleSlice(const HloInstruction* slice) { + current_properties_[kBytesAccessedKey] = shape_size_(slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicSlice( + const HloInstruction* dynamic_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_slice->shape()) * 2; return Status::OK(); } -Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) { +Status HloCostAnalysis::HandleDynamicUpdateSlice( + const HloInstruction* dynamic_update_slice) { + current_properties_[kBytesAccessedKey] = + shape_size_(dynamic_update_slice->operand(1)->shape()) * 2; return Status::OK(); } @@ -386,6 +393,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d17678d20f2a23fd98d18b77d5fb25853901a789..0d66736fe1d0677d13a63ede7a203d6ac20c76f5 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleGenerateToken(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 16fdda8a8b9ade09ea31cda1f4cf5e8ff2c0a081..d22bef56730da194816b4ee89dc3196439b350f9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -460,5 +460,51 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { EXPECT_EQ(analysis.flop_count(), 1472); } +TEST_F(HloCostAnalysisTest, Slice) { + // Test the analysis on a slice. + XlaBuilder builder("slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.Slice(x, {0}, {1}, {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicSlice(x, builder.ConstantR1({1}), {1}); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + +TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { + // Test the analysis on a slice. + XlaBuilder builder("dynamic-update-slice"); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "x"); + auto slice = builder.DynamicUpdateSlice(x, builder.ConstantR1({1.0}), + builder.ConstantR1({1})); + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis(ShapeSize); + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 8); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index c17c26c5a435fe34dd1024d596004cf6b5fdce8c..a0ee8896230d6dcacb5a8eb607fc00ae5226cfa5 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -41,16 +42,16 @@ namespace { // Find and combine identical constants. Constants are identical if they have // the same type and value. -bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { - bool changed = false; - +StatusOr CombineConstants(HloComputation* computation, + bool is_layout_sensitive) { + TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); // Map from ShortDebugString of the layoutless shape of the constant to the // set of constant instructions with that shape. Layoutless shape is used to // bin possible common constants together to reduce number of constant // comparisons. If we end up having too many constant comparisons, a more // precise binning might have to be used. std::multimap constants; - + int64 combined = 0; auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { HloInstruction* instruction = *inst_it; @@ -70,7 +71,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (instruction->literal() == it->second->literal()) { + if (instruction->literal() == it->second->literal() && + domain_map->InSameDomain(it->second, instruction)) { match = it->second; break; } @@ -81,12 +83,13 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { // Match found, replace this instruction with the one in the multimap. TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); TF_CHECK_OK(computation->RemoveInstruction(instruction)); - changed = true; + ++combined; } } } - - return changed; + VLOG(4) << "Combined " << combined << " constants in " << computation->name() + << " computation"; + return combined > 0; } // An instruction is considered to be equivalent to another only if they @@ -123,24 +126,27 @@ StatusOr HloCSE::Run(HloModule* module) { continue; } - changed |= CombineConstants(computation, is_layout_sensitive_); + TF_ASSIGN_OR_RETURN(bool combined, + CombineConstants(computation, is_layout_sensitive_)); + changed |= combined; // HLO instructions are grouped into equivalency classes by using the // cse_equal predicate defined above. This set holds a representative // instruction for each class. tensorflow::gtl::FlatSet - representatives(/*N=*/1024, &CseHash, cse_equal); - + representatives(/*N=*/computation->instruction_count() + 1, &CseHash, + cse_equal); for (auto instruction : computation->MakeInstructionPostOrder()) { // If the instruction has zero operands (constants, parameters, etc.) skip // over it. if (instruction->operand_count() == 0) { continue; } - - // Skip instructions which have side effects. - if (instruction->HasSideEffect()) { + // Skip instructions which have side effects or are a domain (which must + // not be CSE-ed). + if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kDomain) { continue; } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 9735764b692238d6a320bcff51e43b98dcadabda..16db374566c727f1f3efe2a6d419f1f3caf0aaf1 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -32,10 +32,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/types.h" @@ -142,31 +142,46 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // Test that constants with the same value but different type are *not* // commoned. auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + std::vector constants; + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0)))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); // Duplicate the float constant to verify something happens. - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + constants.push_back(builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f)))); + + const Shape shape_r0 = ShapeUtil::MakeShape(F32, {}); + for (int64 i = 0; i < constants.size(); ++i) { + constants[i] = builder.AddInstruction( + HloInstruction::CreateConvert(shape_r0, constants[i])); + } + HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, constants[0], constants[1])); + for (int64 i = 2; i < constants.size(); ++i) { + root = builder.AddInstruction(HloInstruction::CreateBinary( + shape_r0, HloOpcode::kAdd, root, constants[i])); + } auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); - EXPECT_EQ(7, computation->instruction_count()); + EXPECT_EQ(20, computation->instruction_count()); HloCSE cse(/*is_layout_sensitive=*/false); EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); - EXPECT_EQ(6, computation->instruction_count()); + // CSE will remove both the second float(42.0f) and the corresponding + // convert/cast. + EXPECT_EQ(18, computation->instruction_count()); } TEST_F(HloCseTest, NonscalarConstants) { @@ -471,7 +486,7 @@ TEST_F(HloCseTest, DoNotCombineCallsToImpureFunctions) { } TEST_F(HloCseTest, CompareComputations) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule m add_computation { @@ -501,5 +516,25 @@ TEST_F(HloCseTest, CompareComputations) { EXPECT_EQ(root->operand(0), root->operand(1)); } +TEST_F(HloCseTest, ConstantsSameValueInDifferentDomains) { + // Test that constants with the same value but in different domains (disjoint + // in this case) are not collapsed. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(2, computation->instruction_count()); + + HloCSE cse(/*is_layout_sensitive=*/false); + EXPECT_FALSE(cse.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(2, computation->instruction_count()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index b06e6c9f3e62f375a9e48f8ef81efe7121bbef94..d0200058683b2db8f5f0469d6c643014881f741e 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -363,7 +363,7 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { bool HloDataflowAnalysis::UpdateConditionalValueSet( HloInstruction* conditional) { CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet( conditional->true_computation()->root_instruction()), &GetInstructionValueSet( @@ -538,7 +538,7 @@ bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) { bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) { CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile); - std::vector inputs = { + const InstructionValueSet* const inputs[] = { &GetInstructionValueSet(xla_while->while_body()->root_instruction()), &GetInstructionValueSet(xla_while->operand(0))}; if (ssa_form_) { @@ -931,16 +931,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( } const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return use.instruction == user->fused_expression_root() && - use.operand_number == 0; + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -967,6 +968,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( use.operand_number == other_add_operand_index; } } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || user->opcode() == HloOpcode::kWhile) { // We eliminated other users in BufferLiveness::live_range_strictly_before, @@ -998,8 +1000,13 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( }) != uses.end(); return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; } - // Check if 'user' is element-wise. - return user->IsElementwise(); + + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 5798326dcbf65c3c34748afb02afab1dc7af9147..db1822ec47a7f52e2c3ef8dcbf433cd787ef75ab 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1974,6 +1974,89 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + NonElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape, "param0")); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0)); + + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(data_shape, neg, {0, 1})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {reverse, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + MultiOutputFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + Shape in_shape = ShapeUtil::MakeShape(F32, {8}); + Shape out_shape = ShapeUtil::MakeShape(PRED, {8}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, in_shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, in_shape, "param1")); + + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0)); + auto copy1 = builder.AddInstruction( + HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {1})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {0})); + EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {1})); +} + +TEST_F(CanShareOperandBufferWithUserTest, + ElementwiseLoopFusionCantAliasOperandBuffer) { + auto builder = HloComputation::Builder(TestName()); + Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {exp, neg}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {}, + fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { auto builder = HloComputation::Builder(TestName()); @@ -2048,6 +2131,46 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { fusion, {})); } +TEST_F(CanShareOperandBufferWithUserTest, + FusedDynamicUpdateSliceWithConvertCantShare) { + auto builder = HloComputation::Builder(TestName()); + + Shape data_shape = ShapeUtil::MakeShape(F32, {8}); + Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8}); + auto tuple = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple")); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape, tuple, 1)); + + auto convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape_bf16, gte1)); + + // Create a DynamicUpdateSlice instruction of tuple element 1. + auto starts = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({2}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({2.f, 2.f, 2.f}))); + auto dynamic_update_slice = + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + data_shape_bf16, convert1, update, starts)); + + auto convert2 = builder.AddInstruction( + HloInstruction::CreateConvert(data_shape, dynamic_update_slice)); + builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2})); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {convert2, dynamic_update_slice, starts, update, convert1}, + HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + // The fusion instruction can't share with tuple element 1. + EXPECT_FALSE( + dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {})); +} + TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { auto builder = HloComputation::Builder(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc new file mode 100644 index 0000000000000000000000000000000000000000..78955db0da02f16eb93689db947dc1190ab7049a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainIsolator::RunContext { + public: + RunContext(HloModule* module, HloDomainIsolator* isolator) + : module_(module), isolator_(isolator) {} + + StatusOr Run(); + + private: + // Inserts a kDomain instruction between parent and operand, in case + // the attribute (ie, sharding) values change between instruction and operand. + // Returns the newly inserted kDomain instruction, or nullptr if no kDomain + // instruction was necessary. + StatusOr CreateDomain(HloInstruction* instruction, + HloInstruction* parent, + HloInstruction* operand); + + HloModule* module_; + HloDomainIsolator* isolator_; +}; + +StatusOr HloDomainIsolator::RunContext::CreateDomain( + HloInstruction* instruction, HloInstruction* parent, + HloInstruction* operand) { + HloInstruction* domain = nullptr; + std::unique_ptr domain_instruction = + isolator_->creator_(instruction, operand); + if (domain_instruction != nullptr) { + domain = operand->parent()->AddInstruction(std::move(domain_instruction)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + } + return domain; +} + +StatusOr HloDomainIsolator::RunContext::Run() { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator"); + + int64 added_domains = 0; + for (HloComputation* computation : module_->computations()) { + // Walk in post order and place all the required kDomain instructions. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() == HloOpcode::kDomain) { + continue; + } + for (HloInstruction* operand : instruction->unique_operands()) { + // When applying multiple domains, we could end up stacking more than + // one in one edge, so here we want to build the effective + // (kDomain-less) instruction->operand edge. + HloInstruction* parent = instruction; + while (operand->opcode() == HloOpcode::kDomain) { + parent = operand; + operand = operand->mutable_operand(0); + } + // Check whether a kDomain is necessary between instruction and operand. + TF_ASSIGN_OR_RETURN(HloInstruction * domain, + CreateDomain(instruction, parent, operand)); + if (domain != nullptr) { + VLOG(4) << "New domain: " << domain->ToString(); + ++added_domains; + } + } + } + } + VLOG(3) << "Added " << added_domains << " kDomain instructions"; + if (added_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator"); + } + return added_domains > 0; +} + +HloDomainIsolator::HloDomainIsolator(DomainCreator creator) + : creator_(std::move(creator)) {} + +StatusOr HloDomainIsolator::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h new file mode 100644 index 0000000000000000000000000000000000000000..e0c5718509dabebb7b9307bf764b0ea1ce7369a0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Domain isolation is the task of placing kDomain instructions between HLO +// instructions having different shrading. A kDomain instruction is essentially +// used to break an HLO graph edge connecting two instructions with different +// sharding. If a set of connected instructions have all the same sharding, no +// kDomain instruciton will be placed. +class HloDomainIsolator : public HloPassInterface { + public: + // Creates a new kDomain instruction for the edge between the use instruction + // (the first HloInstruction argument), and the operand instruction (the + // second HloInstruction argument). + // Returns nullptr in case no domain separation is necessary. + using DomainCreator = std::function( + HloInstruction*, HloInstruction*)>; + + explicit HloDomainIsolator(DomainCreator creator); + + tensorflow::StringPiece name() const override { return "domain_isolator"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + DomainCreator creator_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_ISOLATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc new file mode 100644 index 0000000000000000000000000000000000000000..ebd5adb5d573ce4b556046f85eb26a6ad59efcb9 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" + +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +/* static */ StatusOr> HloDomainMap::Create( + HloComputation* computation, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + return std::move(domain_map); +} + +/* static */ StatusOr> HloDomainMap::Create( + HloModule* module, string domain_kind) { + auto domain_map = WrapUnique(new HloDomainMap(std::move(domain_kind))); + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(domain_map->Populate(computation)); + } + return std::move(domain_map); +} + +bool HloDomainMap::InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const { + int64 domain_id1 = FindOrDefault(instruction_to_domain_, instruction1, -1); + int64 domain_id2 = FindOrDefault(instruction_to_domain_, instruction2, -1); + return domain_id1 >= 0 && domain_id1 == domain_id2; +} + +Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) { + TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain); + // We only check operands, so we are sure to not process the empty domain from + // both sides. + for (HloInstruction* operand : instruction->unique_operands()) { + if (IsDomainInstruction(operand)) { + auto domain = MakeUnique(); + domain->enter_domains.insert(operand); + domain->exit_domains.insert(instruction); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + } + return Status::OK(); +} + +Status HloDomainMap::Populate(HloComputation* computation) { + for (HloInstruction* instruction : computation->instructions()) { + if (IsDomainInstruction(instruction)) { + // If this is a kDomain of the kind we are currently processing, check + // whether this is an "empty domain". + TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction)); + continue; + } + int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1); + if (domain_id >= 0) { + // We have already processed this instruction. + continue; + } + TF_ASSIGN_OR_RETURN(std::unique_ptr domain, + CreateDomain(instruction)); + TF_RETURN_IF_ERROR(InsertDomain(std::move(domain))); + } + return Status::OK(); +} + +Status HloDomainMap::InsertDomain( + std::unique_ptr domain) { + int64 domain_id = instruction_domains_.size(); + instruction_domains_.push_back(std::move(domain)); + for (HloInstruction* instruction : instruction_domains_.back()->reach_set) { + instruction_to_domain_[instruction] = domain_id; + } + return Status::OK(); +} + +Status HloDomainMap::ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const { + std::vector in_queue; + in_queue.push_back(instruction); + while (!in_queue.empty()) { + HloInstruction* current_instruction = in_queue.back(); + in_queue.pop_back(); + if (domain->reach_set.insert(current_instruction).second) { + // We should not be finding instructions with assigned domain here. + // If we assigned a domain to the instruction, it means that all the + // instructions reached by it, should have a domain as well. + int64 domain_id = + FindOrDefault(instruction_to_domain_, current_instruction, -1); + TF_RET_CHECK(domain_id < 0) + << "Instruction " << current_instruction->ToString() + << " already has domain " << domain_id; + for (HloInstruction* operand : current_instruction->operands()) { + if (IsDomainInstruction(operand)) { + // The reach set instruction is a user of the domain instruction + // (the instruction sees the kDomain as operand). + // IOW the dataflow enters the domain through the kDomain instruction. + domain->enter_domains.insert(operand); + } else { + in_queue.push_back(operand); + } + } + for (HloInstruction* user : current_instruction->users()) { + if (IsDomainInstruction(user)) { + // The reach set instruction is an operand of the domain instruction + // (the instruction sees the kDomain as user). + // IOW the dataflow exits the domain through the kDomain instruction. + domain->exit_domains.insert(user); + } else { + in_queue.push_back(user); + } + } + } + } + return Status::OK(); +} + +StatusOr> HloDomainMap::CreateDomain( + HloInstruction* instruction) const { + auto domain = MakeUnique(); + TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get())); + domain->instructions = MakeNonDomainInstructions(domain->reach_set); + return std::move(domain); +} + +bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const { + if (instruction->opcode() != HloOpcode::kDomain) { + return false; + } + if (!domain_kind_.empty()) { + if (instruction->user_side_metadata().Kind() != domain_kind_) { + return false; + } + // Both user and operand side of the metadata must be of the same kind. + CHECK(instruction->operand_side_metadata().Kind() == domain_kind_) + << "Instruction " << instruction->ToString() + << " has mismatching metadata kinds"; + } + return true; +} + +/* static */ std::vector +HloDomainMap::MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set) { + std::vector instructions; + instructions.reserve(instruction_set.size()); + for (HloInstruction* instruction : instruction_set) { + if (instruction->opcode() != HloOpcode::kDomain) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [](HloInstruction* a, HloInstruction* b) { + return a->unique_id() < b->unique_id(); + }); + return instructions; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h new file mode 100644 index 0000000000000000000000000000000000000000..e62ef763fb3881ab6030b1f6a66266ac80a3d84d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_map.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// The HloDomainMap splits a set of instructions within a module or computation, +// into different domains, separated by kDomain instructions. +// A domain is composed by a set of instructions which can reach each other via +// operand/user edges, without crossing a kDomain insutrction of a given kind. +// A domain never crosses computation boundaries. +class HloDomainMap { + public: + // Creates a new HloDomainMap, creating all the domains within the input + // computation, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create( + HloComputation* computation, string domain_kind); + + // Creates a new HloDomainMap, creating all the domains within the input + // module, of the given kind. If domain_kind is not empty, only the + // kDomain instructions of domain_kind will be considered as separators. + // Otherwise every kDomain instruction will be splitting domains. + static StatusOr> Create(HloModule* module, + string domain_kind); + + // Retrieves all the domains the input module or computation are composed by. + const std::vector>& GetDomains() + const { + return instruction_domains_; + } + + // Checks whether two instructions are within the same domain. + bool InSameDomain(HloInstruction* instruction1, + HloInstruction* instruction2) const; + + // Checks whether instruction is a kDomain instruction of the kind we are + // currently processing. + bool IsDomainInstruction(HloInstruction* instruction) const; + + private: + HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {} + + // Check if the kDomain instruction is facing (via its operand link) another + // kDomain instruction of the same kind, hence defining an empty domain. + // If that is the case, create the empty domain and call the proper + // normalizer. + Status TryProcessEmptyDomain(HloInstruction* instruction); + + Status Populate(HloComputation* computation); + + // Inserts the provided domain into the ones tracked by this object, + // creating a new domain ID. + Status InsertDomain(std::unique_ptr domain); + + // From the given instruction, epxands operand and user wise, the set of + // instructions which can be reached without crossing a kDomain instruction + // of the kind specified by domain_kind_. + // The domain data structure will be populated with all the reached + // instructions, and the boundaries of the domain, with the kDomain + // instructions encountered while expanding the reach. + Status ExpandDomain(HloInstruction* instruction, + DomainMetadata::Domain* domain) const; + + // Creates a domain data structure using the ExpandDomain() API. + StatusOr> CreateDomain( + HloInstruction* instruction) const; + + // Out of an instruction set, returns a vector of all the ones which are not + // a kDomain kind. + static std::vector MakeNonDomainInstructions( + const tensorflow::gtl::FlatSet& instruction_set); + + string domain_kind_; + std::vector> instruction_domains_; + tensorflow::gtl::FlatMap instruction_to_domain_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_MAP_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..aa0308100a21f109579de75788fce7d242d6a6b0 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Cannot include hlo_instruction.h as this file is included from there. +class HloInstruction; + +// The DomainMetadata represents the base class for metadata which can be +// attached to kDomain HLO instructions. +class DomainMetadata { + public: + // A Domain data structure captures all the information about a kDomain + // bounded instruction set. + struct Domain { + // The set of instructions which are reachable from each other via + // operand/user pathways, without crossing a kDomain instruction of a given + // kind. The reach_set can contain kDomain instructions of other kinds, if + // two domains of different kind intersect each other. + tensorflow::gtl::FlatSet reach_set; + + // The same instructions in reach_set, but purged from kDomain instructions. + std::vector instructions; + + // If we consider a graph edge as an arrow oriented from the operand to the + // user, the enter_domains will contain the set of kDomain instructions + // whose dataflow enters the reach set (domain), while the exit_domains + // contains the set of kDomain instructions whose dataflow exit the reach + // set. + tensorflow::gtl::FlatSet enter_domains; + tensorflow::gtl::FlatSet exit_domains; + }; + + virtual ~DomainMetadata() = default; + + // Clones the metadata object. + virtual std::unique_ptr Clone() const = 0; + + // Returns the metadata type. A unique identifier which describes the real + // metadata type. + virtual tensorflow::StringPiece Kind() const = 0; + + // Compares the metadata object with another one and returns true if the + // two matches. + virtual bool Matches(const DomainMetadata& other) const = 0; + + // Returns a string representation of the metadata. + virtual string ToString() const = 0; + + // Given a reachable set (the set of instructions which are reachable from + // each other via user/operand pathways, without crossing a kDomain + // instruciton), makes sure that all of them have metadata attributes which + // are coherent with this metadata object. + virtual Status NormalizeInstructions(const Domain& domain) const = 0; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.cc b/tensorflow/compiler/xla/service/hlo_domain_remover.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d06040b0e7c92b03f4cb5481bdee73a0f74f939 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.cc @@ -0,0 +1,149 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_map.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +class HloDomainRemover::RunContext { + public: + RunContext(HloModule* module, HloDomainRemover* remover) + : module_(module), remover_(remover) {} + + StatusOr Run(); + + private: + // Verifies the consistency of the domain, and normalizes the instructions + // within it. + Status VerifyAndNormalizeDomain(const DomainMetadata::Domain& domain); + + HloModule* module_; + HloDomainRemover* remover_; +}; + +Status HloDomainRemover::RunContext::VerifyAndNormalizeDomain( + const DomainMetadata::Domain& domain) { + // Verify that the whole kDomain frontier bounding the instruction reach set, + // has matching metadata. + // A kDomain instruction has two sides of metadata, a user facing and an + // operand facing. + // A reachable instruction set can make contact with a kDomain instruction on + // a user facing side (the kDomain is operand of the instruction), or on a + // operand facing side (the kDomain is user of the instruction). + // And depending on the contact side, the proper metadata object + // (user_side_metadata() vs. operand_side_metadata()) needs to be used for + // consistency checks. + const DomainMetadata* ref_metadata = nullptr; + VLOG(4) << "Reach set:"; + for (HloInstruction* instruction : domain.instructions) { + VLOG(4) << " " << instruction->name(); + } + VLOG(4) << " Domains:"; + for (HloInstruction* instruction : domain.enter_domains) { + const DomainMetadata& meta = instruction->user_side_metadata(); + VLOG(4) << " User side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + for (HloInstruction* instruction : domain.exit_domains) { + const DomainMetadata& meta = instruction->operand_side_metadata(); + VLOG(4) << " Operand side: " << instruction->name(); + VLOG(4) << " " << meta.ToString(); + if (ref_metadata == nullptr) { + ref_metadata = &meta; + } else { + TF_RET_CHECK(meta.Matches(*ref_metadata)) + << "Metadata mismatch at instruction " << instruction->name() << " : " + << meta.ToString() << " vs " << ref_metadata->ToString(); + } + } + if (ref_metadata != nullptr) { + VLOG(4) << "Applying domain normalization: " << ref_metadata->ToString(); + TF_RETURN_IF_ERROR(ref_metadata->NormalizeInstructions(domain)); + } else { + // No kDomain instruction was present within this domain, so call the + // generic normalization functions and have them apply their heuristic. + VLOG(2) << "Applying domain-less normalization"; + TF_RETURN_IF_ERROR(remover_->normalizer_(domain)); + } + return Status::OK(); +} + +StatusOr HloDomainRemover::RunContext::Run() { + VLOG(4) << "Processing metadata domain: '" << remover_->kind_ << "'"; + hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Remover"); + + int64 removed_domains = 0; + for (HloComputation* computation : module_->computations()) { + // First create the domain instruciton sets. A domain instruction set is + // the set of instructions whose edges never cross a kDomain instruction. + TF_ASSIGN_OR_RETURN(std::unique_ptr domain_map, + HloDomainMap::Create(computation, remover_->kind_)); + // Verify and normalize every domain populated within the map. + for (auto& domain : domain_map->GetDomains()) { + TF_RETURN_IF_ERROR(VerifyAndNormalizeDomain(*domain)); + } + + // Now remove all the kDomain instructions of the kind specified by the + // remover, that are within the currently processed computation from the + // graph. + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + for (HloInstruction* operand : instruction->unique_operands()) { + if (domain_map->IsDomainInstruction(operand)) { + VLOG(5) << "Removing " << operand->name(); + TF_RETURN_IF_ERROR( + operand->ReplaceAllUsesWith(operand->mutable_operand(0))); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand)); + ++removed_domains; + } + } + } + HloInstruction* root = computation->root_instruction(); + if (root != nullptr && domain_map->IsDomainInstruction(root)) { + VLOG(5) << "Removing " << root->name(); + computation->set_root_instruction(root->mutable_operand(0)); + TF_RETURN_IF_ERROR(computation->RemoveInstruction(root)); + ++removed_domains; + } + } + VLOG(3) << "Removed " << removed_domains << " kDomain instructions of '" + << remover_->kind_ << "' kind"; + if (removed_domains > 0) { + hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Remover"); + } + return removed_domains > 0; +} + +StatusOr HloDomainRemover::Run(HloModule* module) { + RunContext run_context(module, this); + return run_context.Run(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_domain_remover.h b/tensorflow/compiler/xla/service/hlo_domain_remover.h new file mode 100644 index 0000000000000000000000000000000000000000..0c71dd34fd4d2944037dc965a2c9ad2c592d6e3e --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_remover.h @@ -0,0 +1,56 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { + +// Removes all the kDomain instructions of a given kind from the input module, +// and calls the normalizer to propagate the properties on the possibly new born +// instructions. +class HloDomainRemover : public HloPassInterface { + public: + // Creates a new HloDomainRemover object tasked at removing all the kDomain + // instructions of a given kind. + // In case a reachable set (the set of instructions within a computation, + // which are mutually reachable via operand/user pathways) has all the + // instructions in it with the same attributes (ie, sharding), a normalizer + // function is tasked at applying attribute normalization on the instructions + // within such domain. + HloDomainRemover( + tensorflow::StringPiece kind, + std::function normalizer) + : kind_(kind.ToString()), normalizer_(std::move(normalizer)) {} + + tensorflow::StringPiece name() const override { return "domain_remover"; } + + StatusOr Run(HloModule* module) override; + + private: + class RunContext; + + string kind_; + std::function normalizer_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DOMAIN_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5553ddb153f7f1f2e6a790890c11f35e192488c4 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_domain_isolator.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_domain_remover.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +class HloDomainTest : public HloTestBase { + protected: + bool FindUserViaDomainPath(HloInstruction* instruction, + HloInstruction* operand) const { + for (HloInstruction* user : operand->users()) { + if (user == instruction) { + return true; + } + if (user->opcode() == HloOpcode::kDomain && + FindUserViaDomainPath(instruction, user)) { + return true; + } + } + return false; + } + + // Checks whether there is a kDomain instruction in the edge between the + // instruction and the operand. + bool HasDomainEdge(HloModule* module, + tensorflow::StringPiece instruction_name, + tensorflow::StringPiece operand_name) { + HloInstruction* instruction = FindInstruction(module, instruction_name); + HloInstruction* operand = FindInstruction(module, operand_name); + CHECK_NE(instruction, nullptr); + CHECK_NE(operand, nullptr); + if (!instruction->IsUserOf(operand)) { + // If instruction is not an immediate user, we must find a path from + // operand to instruction anyway, otherwise there is a corruption. + if (FindUserViaDomainPath(instruction, operand)) { + return true; + } + LOG(FATAL) << "Bad HLO module generated across the '" << instruction_name + << "' and '" << operand_name << "' instructions:\n" + << module->ToString(); + } + return false; + } + + StatusOr> ParseModule( + tensorflow::StringPiece hlo_string) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return ParseHloString(hlo_string, config); + } +}; + +// Dummy DomainMetadata implementation which create kDomain boundaries around +// HLO instructions with the same metadata().op_name() values. +class OpNameMetadata : public DomainMetadata { + public: + explicit OpNameMetadata(string opname) : opname_(std::move(opname)) {} + + std::unique_ptr Clone() const override { + return MakeUnique(opname_); + } + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override { + const OpNameMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a OpNameMetadata, then it is clearly a no match. + return false; + } + return opname_ == other_ptr->opname_; + } + + string ToString() const override { return opname_; } + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override { + // For the purposes of this test, nothing to do. + return Status::OK(); + } + + static tensorflow::StringPiece KindName() { return "opname"; } + + private: + string opname_; +}; + +// Creator function for OpNameMetadata domains. +std::unique_ptr OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* operand) { + if (instruction->metadata().op_name() == operand->metadata().op_name()) { + return nullptr; + } + std::unique_ptr operand_side_metadata = + MakeUnique(operand->metadata().op_name()); + std::unique_ptr user_side_metadata = + MakeUnique(instruction->metadata().op_name()); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain) { + // Nothing to do for the particular use this test make of the OpName domains. + return Status::OK(); +} + +TEST_F(HloDomainTest, CheckDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b), sharding={maximal device=1} + d = f32[4] subtract(a, b), sharding={maximal device=1} + e = f32[4] multiply(c, d), sharding={maximal device=1} + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedIfNoSharding) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(f32[4] a, f32[4] b) + d = f32[4] subtract(a, b) + e = f32[4] multiply(c, d) + ROOT f = (f32[4], f32[4], f32[4]) tuple(c, d, e) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(!isolator_changed); +} + +TEST_F(HloDomainTest, CheckDomainAroundIO) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = (f32[4], u32[]) send(a), channel_id=1, sharding={maximal device=0} + c = () send-done(b), channel_id=1, sharding={maximal device=0} + d = (f32[4], u32[]) recv(), channel_id=2, sharding={maximal device=0} + e = f32[4] recv-done(d), channel_id=2, sharding={maximal device=0} + f = f32[4] add(a, e) + g = f32[4] subtract(a, e) + ROOT h = (f32[4], f32[4]) tuple(f, g) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "f", "e")); + EXPECT_FALSE(HasDomainEdge(module.get(), "a", "p0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "b", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "f", "e")); +} + +TEST_F(HloDomainTest, CheckNoDomainAddedOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=-1} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=-1} + c = f32[4] add(b, b), sharding={maximal device=-1} + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=-1} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=-1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_FALSE(isolator_changed); +} + +TEST_F(HloDomainTest, CheckNormalizationOnPureIOComputation) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + a = (f32[4], u32[]) recv(), channel_id=1, sharding={maximal device=0} + b = f32[4] recv-done(a), channel_id=1, sharding={maximal device=0} + c = f32[4] add(b, b) + d = (f32[4], u32[]) send(c), channel_id=2, sharding={maximal device=0} + ROOT e = () send-done(d), channel_id=2, sharding={maximal device=0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_FALSE(remover_changed); + + HloInstruction* add = FindInstruction(module.get(), "c"); + ASSERT_NE(add, nullptr); + auto device = add->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, 0); +} + +TEST_F(HloDomainTest, CheckMultiDomainLinks) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[4], f32[4]) parameter(0) + a = f32[4] get-tuple-element(p0), index=0 + b = f32[4] get-tuple-element(p0), index=1 + c = f32[4] add(a, b), sharding={maximal device=1} + d = f32[4] subtract(a, c), sharding={maximal device=1}, metadata={op_name="D"} + e = f32[4] multiply(c, d), sharding={maximal device=1}, metadata={op_name="D"} + f = f32[4] add(e, c), sharding={maximal device=1} + ROOT g = (f32[4], f32[4], f32[4]) tuple(c, d, f) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator sharding_isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed, + sharding_isolator.Run(module.get())); + EXPECT_TRUE(sharding_isolator_changed); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module.get())); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_TRUE(HasDomainEdge(module.get(), "d", "c")); + EXPECT_FALSE(HasDomainEdge(module.get(), "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module.get())); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module.get())); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "c", "b")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "a")); + EXPECT_FALSE(HasDomainEdge(module.get(), "d", "c")); +} + +TEST_F(HloDomainTest, CheckNormalizationOnInfeedTuple) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + infeed = (f32[4], f32[4]) infeed(), + sharding={{maximal device=1}, {maximal device=0}} + gte0 = f32[4] get-tuple-element(infeed), index=0 + gte1 = f32[4] get-tuple-element(infeed), index=1 + copy0 = f32[4] copy(gte0) + copy1 = f32[4] copy(gte1) + ROOT add = f32[4] add(copy0, copy1) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator isolator(CreateShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module.get())); + EXPECT_TRUE(isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module.get(), "gte0", "infeed")); + EXPECT_TRUE(HasDomainEdge(module.get(), "gte1", "infeed")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy0", "gte0")); + EXPECT_FALSE(HasDomainEdge(module.get(), "copy1", "gte1")); + + // Inject unassigned tuple/gte within the infeed domain, to simulate the + // HLO passes adding unexpected instructions. + // + // infeed + // / \ + // GTE0 GTE1 + // / \ + // COPY0 COPY1 + // \ / + // \ / + // TUPLE + // | + // DOMAIN + HloInstruction* infeed = FindInstruction(module.get(), "infeed"); + ASSERT_NE(infeed, nullptr); + auto infeed_users = infeed->users(); + HloInstruction* new_gte0 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 0), infeed, 0)); + HloInstruction* new_copy0 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte0->shape(), HloOpcode::kCopy, new_gte0)); + HloInstruction* new_gte1 = + infeed->parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(infeed->shape(), 1), infeed, 1)); + HloInstruction* new_copy1 = + infeed->parent()->AddInstruction(HloInstruction::CreateUnary( + new_gte1->shape(), HloOpcode::kCopy, new_gte1)); + HloInstruction* new_tuple = infeed->parent()->AddInstruction( + HloInstruction::CreateTuple({new_copy0, new_copy1})); + for (HloInstruction* user : infeed_users) { + TF_EXPECT_OK(infeed->ReplaceUseWith(user, new_tuple)); + } + + HloDomainRemover remover(ShardingMetadata::KindName(), + NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool remover_changed, remover.Run(module.get())); + EXPECT_TRUE(remover_changed); + + struct Assignment { + HloInstruction* instruction; + int64 device; + } assignments[] = { + {new_gte0, 1}, + {new_copy0, 1}, + {new_gte1, 0}, + {new_copy1, 0}, + }; + for (auto& assignment : assignments) { + auto device = assignment.instruction->sharding_unique_device(); + EXPECT_TRUE(device.has_value()); + EXPECT_EQ(*device, assignment.device); + } + EXPECT_TRUE(new_tuple->has_sharding()); + EXPECT_EQ( + new_tuple->sharding(), + HloSharding::Tuple(new_tuple->shape(), {HloSharding::AssignDevice(1), + HloSharding::AssignDevice(0)})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc index d236f83aeb9254b9c6e6d04629758ac2c8fd0da3..4ed1508d7067684a15d0fb7d86e69b055bc1333b 100644 --- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -119,6 +119,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { return false; } + HloCloneContext context(module); bool changed = false; for (auto* computation : module->computations()) { for (auto* hlo : computation->MakeInstructionPostOrder()) { @@ -140,6 +141,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { // These are ops with embedded computations where it suffices to convert // the embedded computations instead of converting the ops themselves. if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kCrossReplicaSum || opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kSelectAndScatter || @@ -180,7 +182,7 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); new_hlo = computation->AddInstruction( - hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + hlo->CloneWithNewOperands(shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); new_hlo = ToElementType(new_hlo, eliminate_type_); @@ -189,16 +191,16 @@ StatusOr HloElementTypeConverter::Run(HloModule* module) { Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, replace_with_type_); - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - new_shape, new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(new_shape, new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); // Convert the elements of the result of `new_hlo` to produce a new // tuple with shape `old_shape`. new_hlo = ConvertTupleElements(new_hlo, old_shape); } else { - new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( - hlo->shape(), new_operands, hlo->GetModule())); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context)); TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo)); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index fa59a5fb2030b22aa9e6a59abbfba521d19adb51..e0648e14672c45e9a691fd6a674c9a2cd7605a12 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -309,6 +309,35 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( return result; } +StatusOr> HloEvaluator::EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs) { + std::unique_ptr lhs_instr = + HloInstruction::CreateConstant(lhs.CloneToUnique()); + std::unique_ptr rhs_instr = + HloInstruction::CreateConstant(rhs.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateBinary(lhs.shape(), opcode, lhs_instr.get(), + rhs_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + +StatusOr> HloEvaluator::EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand) { + std::unique_ptr operand_instr = + HloInstruction::CreateConstant(operand.CloneToUnique()); + + std::unique_ptr cloned_instruction = + HloInstruction::CreateUnary(operand.shape(), opcode, operand_instr.get()); + auto result = Evaluate(cloned_instruction.get()); + + cloned_instruction->DetachFromOperands(); + return result; +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; @@ -859,6 +888,36 @@ Status HloEvaluator::HandleGather(HloInstruction* gather) { return Status::OK(); } +Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { + const Literal& operand = GetEvaluatedLiteralFor(broadcast->operand(0)); + + TF_RET_CHECK(broadcast->dimensions().size() == + ShapeUtil::Rank(operand.shape())) + << "broadcast dimensions is of size: " << broadcast->dimensions().size() + << " and rank of operand_to_broadcast is: " + << ShapeUtil::Rank(operand.shape()); + // Checks that operand's dimensions are the same as the broadcast's + // dimensions along the dimensions to be broadcasted. + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == + operand.shape().dimensions(i)); + } + + TF_ASSIGN_OR_RETURN( + evaluated_[broadcast], + operand.Broadcast(broadcast->shape(), broadcast->dimensions())); + + return Status::OK(); +} + +Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { + // Literals cannot represent a TOKEN shape so just create an empty tuple as + // the "result" of the kGenerateToken operation. + // TODO(b/109929053): Add support for TOKENs in Literals. + evaluated_[token] = Literal::MakeTuple({}); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); @@ -914,9 +973,10 @@ Status HloEvaluator::HandleFusion(HloInstruction* fusion) { // Attach cloned computation to an empty HLO module so the existing ones are // not modified. HloModule empty_hlo_module("EmptyModuleForFusion", config); + HloCloneContext context(&empty_hlo_module); auto cloned_fused_computation = fusion->fused_instructions_computation()->Clone( - /*suffix=*/"clone_with_layout", &empty_hlo_module); + /*suffix=*/"clone_with_layout", &context); for (auto* instruction : cloned_fused_computation->instructions()) { LayoutUtil::SetToDefaultLayout(instruction->mutable_shape()); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 566d53a41427119ea3d429a60a4430068bc953b1..fc2fc9437b238a2e519401b2b121dfbef070e2dc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -109,6 +109,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { const std::unordered_map& substitutions); + StatusOr> EvaluateElementwiseBinaryOp( + HloOpcode opcode, const Literal& lhs, const Literal& rhs); + + StatusOr> EvaluateElementwiseUnaryOp( + HloOpcode opcode, const Literal& operand); + protected: // Make HloEvaluatorTypedVisitor a friend because it is logically part of this // class. @@ -166,6 +172,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; + Status HandleBroadcast(HloInstruction* broadcast) override; + + Status HandleGenerateToken(HloInstruction* token) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index ae5b5e0412ef99db9b72d645a954759ca0b9eb8b..72eb9930e92c340ab9f42cd563c27507623b2ba7 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -262,13 +262,13 @@ TEST_P(HloEvaluatorTest, DoesCosR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesSinR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), - use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); + use_bfloat16_ ? 0.031250 : 9.5367431640625E-7); } TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = @@ -333,7 +333,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) { result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0.031250); }); } @@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5))); + EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0.031250))); } TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { @@ -1248,7 +1248,7 @@ void BM_ReducePrecisely(int num_iters) { HloComputation::Builder b("BM_ReducePrecisely"); HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_ReducePrecisely", VersionedComputationHandle(), config); + HloModule module("BM_ReducePrecisely", config); constexpr int kNumElements = 1 << 25; // float += 1 saturates at 1<<24 std::vector v(kNumElements, 1.0f); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 024e8751f79b8b73cb868f6cbd4603f3e94ca7ea..13f46407e33e36bdbef4c9032630101d6c18268f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -161,36 +161,6 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleRound(round); } - Status HandleBroadcast(HloInstruction* broadcast) override { - const Literal& operand_to_broadcast = - parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); - std::vector broadcast_indices( - ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); - - TF_RET_CHECK(broadcast->dimensions().size() == - ShapeUtil::Rank(operand_to_broadcast.shape())) - << "broadcast dimensions is of size: " << broadcast->dimensions().size() - << " and rank of operand_to_broadcast is: " - << ShapeUtil::Rank(operand_to_broadcast.shape()); - // Checks that operand's dimensions are the same as the broadcast's - // dimensions along the dimensions to be broadcasted. - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) == - operand_to_broadcast.shape().dimensions(i)); - } - - auto output = MakeUnique(broadcast->shape()); - TF_RETURN_IF_ERROR(output->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { - broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; - } - return operand_to_broadcast.Get(broadcast_indices); - })); - parent_->evaluated_[broadcast] = std::move(output); - return Status::OK(); - } - template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -1482,11 +1452,12 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // Evaluate computation with specified literal operands. auto curr_val_literal = Literal::CreateR0(curr_val); auto result_val_literal = Literal::CreateR0(result_val); - std::vector args = {result_val_literal.get(), - curr_val_literal.get()}; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); // Clear visit states so that we can use the evaluator again on // the same computation. @@ -1685,10 +1656,11 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { Literal::CreateR0(curr_val); const auto result_val_literal = Literal::CreateR0(result_val); - const std::vector args = { - result_val_literal.get(), curr_val_literal.get()}; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator + .Evaluate( + *function, + {result_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); // Clear visit states so that the we can use the evaluate again @@ -1990,10 +1962,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { // TODO(b/74360564): This is implementation defined behavior, but is // currently respected by all implementations. Change this if we ever decide - // to oficially document different behavior. + // to officially document different behavior. for (int64 i = 0; i < start.size(); ++i) { start[i] = std::min( - std::max(0LL, start[i]), + std::max(int64{0}, start[i]), operand_literal.shape().dimensions(i) - result_shape.dimensions(i)); } diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index 4900c813fdf037e65c6b42d027f1cbefb6ee9830..eba80c0f199f6224f4b46ac19af482c713585154 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { @@ -29,7 +29,7 @@ using ::testing::ContainsRegex; class HloExecutionProfileTest : public HloTestBase {}; TEST_F(HloExecutionProfileTest, Basic) { - auto hlo_module = tools::Parse(R"( + auto hlo_module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { lhs = f32[30,30]{1,0} parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 17e3c405f1e5269ddf2f03c031a1137f9bb14fcc..28fc6c4209bcc14d890f28d6a9935c55b76586c0 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -28,6 +28,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -321,13 +323,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - const DebugOptions& debug_options, bool show_metadata, - bool show_backend_config, const HloExecutionProfile* profile, - NodeFilter filter) + const DebugOptions& debug_options, bool show_backend_config, + const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(std::string(label)), debug_options_(debug_options), - show_metadata_(show_metadata), show_backend_config_(show_backend_config), profile_(profile), filter_(std::move(filter)) {} @@ -395,7 +395,6 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph const DebugOptions& debug_options_; - const bool show_metadata_; const bool show_backend_config_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -430,7 +429,8 @@ class HloDotDumper { // When coloring by sharding information, we track the sharding string // representation to color association, by round-robin the color schemes. - std::unordered_map sharding_colors_; + std::unordered_map + sharding_colors_; int64 next_shard_color_ = 0; }; @@ -592,15 +592,26 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr) { VLOG(2) << "Dumping subcomputation " << subcomp->name(); - const char* computation_fmt = R"(subgraph %s { -%s -label = <%s>; -labelloc = t; -tooltip = " "; -%s -} // %s + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); + const char* edge_fmt = + R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); + } -)"; + // Have we already dumped this subcomputation? If so, generating the edge + // linking it and parent_instr is all we want to do in this function. + if (cluster_ids_.find(subcomp) != cluster_ids_.end()) { + return ""; + } cluster_ids_[subcomp] = next_cluster_id_++; @@ -647,25 +658,16 @@ tooltip = " "; string comp_body = DumpComputation(subcomp); - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, - // so there's no need to link those. - if (parent_instr->opcode() != HloOpcode::kFusion) { - const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); - VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() - << " as " << next_edge_id_; - edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); - const char* edge_fmt = - R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back(Printf( - edge_fmt, InstructionId(from), InstructionId(parent_instr), - SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); - } - - string computation = - Printf(computation_fmt, id, style, subcomp_label, comp_body, id); + const char* computation_fmt = R"(subgraph %s { +%s +label = <%s>; +labelloc = t; +tooltip = " "; +%s +} // %s - return computation; +)"; + return Printf(computation_fmt, id, style, subcomp_label, comp_body, id); } string HloDotDumper::DumpComputation(const HloComputation* comp) { @@ -723,11 +725,25 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +static const HloConstantInstruction* TryGetFusionParameterConstant( + const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) { + return nullptr; + } + const HloInstruction* fusion = instr->parent()->FusionInstruction(); + const HloInstruction* operand = fusion->operand(instr->parameter_number()); + return DynCast(operand); +} + bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // If a node: // - // - is a tuple-shaped parameter, - // - is not a parameter to a fusion node, + // - is a parameter of a fusion node which is bound to a constant, + // + // or + // + // - is a tuple-shaped parameter, and + // - is not a parameter to a fusion node, and // - has at least kMinUsersToOmit users shown, and // - all of the shown users are get-tuple-elements, // @@ -735,6 +751,9 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { // // This helps us handle the common case where a while loop body has one big // tuple-shaped parameter. + if (TryGetFusionParameterConstant(instr) != nullptr) { + return true; + } const int kMinUsersToOmit = 3; return instr->opcode() == HloOpcode::kParameter && ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && @@ -791,22 +810,22 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { } // Build the text that will be displayed inside the node. string node_body = node_label; - for (const string& s : {trivial_subcomputation, node_metadata, - node_backend_config, extra_info, inlined_constants}) { + for (const string& s : {trivial_subcomputation, node_backend_config, + extra_info, inlined_constants}) { if (!s.empty()) { StrAppend(&node_body, "
", s); } } - return Printf(R"(%s [label=<%s>, shape=%s, tooltip=" ", %s];)" + return Printf(R"(%s [label=<%s>, shape=%s, tooltip="%s", %s];)" "\n", - InstructionId(instr), node_body, node_shape, + InstructionId(instr), node_body, node_shape, node_metadata, NodeColorAttributes(color)); } string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloInstruction* constant) { + auto stringify_constant = [](const HloConstantInstruction* constant) { const auto& shape = constant->shape(); // If the shape has a dimension of size zero, print it as e.g. @@ -825,7 +844,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( *elem_count *= dim; } } - if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { + if (elem_count.has_value() && *elem_count <= 8) { return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } @@ -841,29 +860,26 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( ShapeUtil::HumanString(constant->shape())); }; - // Special case: If instr is a parameter to a fusion node, check whether the - // corresponding operand to the fusion node is a constant. - if (instr->opcode() == HloOpcode::kParameter && instr->IsFused()) { - const HloInstruction* fusion = instr->parent()->FusionInstruction(); - const HloInstruction* operand = fusion->operand(instr->parameter_number()); - if (operand->opcode() != HloOpcode::kConstant) { - return ""; - } - return StrCat("constant ", stringify_constant(operand)); - } - std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); + const auto* constant_operand = DynCast(operand); optional operand_str; - if (operand->opcode() == HloOpcode::kConstant) { - operand_str = stringify_constant(operand); + if (constant_operand != nullptr) { + operand_str = stringify_constant(constant_operand); } else if (ShouldMergeIntoUsers(operand)) { - // Special case: If the operand is a parameter, use its parameter number - // rather than its name, because that's generally how people think of the - // node. + // Special case: If the operand is a parameter to a fusion node and it + // always has a constant value, display it like a regular constant. + // + // For other parameters, use the parameter number rather than the proper + // name, because that's generally how people think of the node. if (operand->opcode() == HloOpcode::kParameter) { - operand_str = Printf("Parameter %lld", operand->parameter_number()); + if (const HloConstantInstruction* constant = + TryGetFusionParameterConstant(operand)) { + operand_str = stringify_constant(constant); + } else { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } } else { operand_str = operand->name(); } @@ -885,24 +901,26 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { if (!instr->has_sharding()) { return kDashedBorder; } - string shard_str = instr->sharding().ToString(); - auto it = sharding_colors_.find(shard_str); + auto it = sharding_colors_.find(instr->sharding()); if (it != sharding_colors_.end()) { return it->second; } ColorScheme color = static_cast( kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); - sharding_colors_.emplace(shard_str, color); + sharding_colors_.emplace(instr->sharding(), color); return color; } const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it - // the same color as a parameter. + // the same color as a parameter. Unless the merged-in parameter is a + // parameter to a fusion node that is bound to a constant -- these aren't + // "real" parameters from the user's perspective. if (std::any_of(instr->operands().begin(), instr->operands().end(), [&](const HloInstruction* operand) { return operand->opcode() == HloOpcode::kParameter && - ShouldMergeIntoUsers(operand); + ShouldMergeIntoUsers(operand) && + TryGetFusionParameterConstant(operand) == nullptr; })) { return kParameterColor; } @@ -965,6 +983,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: @@ -976,7 +995,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kGreen; case HloOpcode::kConcatenate: - case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: case HloOpcode::kPad: @@ -998,6 +1016,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kCopy: + // Emphasize copy nodes, which are either physical transposes (and thus + // significant), or copies of read-only buffers (and thus dead weight). + return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: case HloOpcode::kFft: @@ -1013,6 +1035,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReduceWindow: case HloOpcode::kSelectAndScatter: return kPurple; + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kMap: return kGray; @@ -1068,10 +1091,6 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { } string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { - if (!show_metadata_) { - return ""; - } - std::vector lines; if (!instr->metadata().op_name().empty()) { lines.push_back(HtmlLikeStringSanitize(instr->metadata().op_name())); @@ -1091,11 +1110,11 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeBackendConfig( const HloInstruction* instr) { - if (!show_backend_config_ || instr->backend_config().empty()) { + if (!show_backend_config_ || instr->raw_backend_config_string().empty()) { return ""; } - return StrCat("backend_config=\"", instr->backend_config(), "\""); + return StrCat("backend_config=\"", instr->raw_backend_config_string(), "\""); } string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { @@ -1154,6 +1173,20 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { return Join(lines, "
"); } +// Gets the total number of array elements in the given shape. For tuples, this +// is the sum of all the sizes of all of the array elements recursively in the +// tuple. +static int64 TotalElementsInShape(const Shape& shape) { + int64 elems = 0; + ShapeUtil::ForEachSubshape( + shape, [&](const Shape& subshape, const ShapeIndex& /*index*/) { + if (ShapeUtil::IsArray(subshape)) { + elems += ShapeUtil::ElementsIn(subshape); + } + }); + return elems; +} + void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { @@ -1173,9 +1206,16 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { } else if (control_edge) { edge_label = "style=\"dotted\" color=\"gray\" label=\"ctrl\""; } - const char* kEdgeFmt = R"(%s -> %s [tooltip="%s -> %s" %s];)"; + + // We print "small" arrays using a hollow arrowhead and "large" arrays using + // a filled arrowhead. For now, we use an arbitrary cutoff for what "big" + // means. + bool is_big_array = TotalElementsInShape(from->shape()) >= 4096; + + const char* kEdgeFmt = R"(%s -> %s [arrowhead=%s tooltip="%s -> %s" %s];)"; edges_.push_back(Printf(kEdgeFmt, InstructionId(from), InstructionId(to), - from->name(), to->name(), edge_label)); + (is_big_array ? "normal" : "empty"), from->name(), + to->name(), edge_label)); }; // Add edges from instr's operands to instr. Parameters within fusion @@ -1425,7 +1465,7 @@ string ExportGraph(const string& graph, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, - bool show_metadata, bool show_backend_config) { + bool show_backend_config) { GraphRendererInterface::GraphKind graph_kind; string graph; if (debug_options.xla_hlo_dump_as_graphdef()) { @@ -1436,8 +1476,8 @@ string DumpGraph(const HloComputation& computation, const string& label, graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { graph = - HloDotDumper(&computation, label, debug_options, show_metadata, - show_backend_config, hlo_execution_profile, NodeFilter()) + HloDotDumper(&computation, label, debug_options, show_backend_config, + hlo_execution_profile, NodeFilter()) .Dump(); graph_kind = GraphRendererInterface::DOT_GRAPH; } @@ -1449,15 +1489,15 @@ string DumpGraph(const HloComputation& computation, const string& label, } string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata, bool show_backend_config) { + bool show_backend_config) { auto debug_options = node.GetModule()->config().debug_options(); string label = StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); - string graph = HloDotDumper(node.parent(), label, debug_options, - show_metadata, show_backend_config, - /*profile=*/nullptr, filter) - .Dump(); + string graph = + HloDotDumper(node.parent(), label, debug_options, show_backend_config, + /*profile=*/nullptr, filter) + .Dump(); return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index fc8e1468aca9c2edbc22c30a41a1be8b32a1feca..0b11f34abb7f0d937a24d11f4dc5d2d6a0aae6e7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label, string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile = nullptr, - bool show_metadata = false, bool show_backend_config = false); + bool show_backend_config = false); // Like DumpGraph, but renders only nodes "near" the given node in the graph. // @@ -64,7 +64,6 @@ string DumpGraph(const HloComputation& computation, const string& label, // (roughly) corresponds to the max distance a node may be from the primary node // before it's omitted from the graph. string DumpNeighborhoodAround(const HloInstruction& node, int radius, - bool show_metadata = false, bool show_backend_config = false); // Dumps the HloModule::ToString() as a file into the provided directory path diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 8e52d926d85f1ce6fabeb2dedd2f8e0fe0c2051d..68f41a1cbb4db228f5dcf8b4a6130f05e81262a8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -121,7 +121,7 @@ TEST(HloGraphDumperTest, Constant) { HloComputation::Builder b("b"); auto instruction = b.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(-42))); - instruction->set_name("i_am_a_constant_root_instruction"); + instruction->SetAndSanitizeName("i_am_a_constant_root_instruction"); HloModuleConfig config; HloModule m(TestName(), config); HloComputation* root_computation = m.AddEntryComputation(b.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index db1c33e2f0dfa0599810ab2e8d32209e64c5c865..39662d1735e4b411ef36cbc8421eabe52be6165a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -16,7 +16,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include -#include #include #include #include @@ -27,7 +26,9 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -37,9 +38,11 @@ limitations under the License. #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/human_readable_json.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -58,68 +61,218 @@ StatusOr> HloInstruction::CreateFromProto( TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); - auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); - for (const int64 operand_id : proto.operand_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) - << "No instruction with id " << operand_id; - instruction->AppendOperand(instruction_map.at(operand_id)); - } - for (const int64 predecessor_id : proto.control_predecessor_ids()) { - TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) - << "No instruction with id " << predecessor_id; - TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) - ->AddControlDependencyTo(instruction.get())); - } - - // In the proto, fused computations are held exclusively within the - // HloInstructionProto and do not appear as an HloComputationProto within the - // HloModuleProto. - if (instruction->opcode() == HloOpcode::kFusion) { - TF_RET_CHECK(!proto.fusion_kind().empty()); - TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, - StringToFusionKind(proto.fusion_kind())); - - // Find the fused computation and set its fusion instruction. - TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Expect 1 called computation for fusion instruction, but sees " - << proto.called_computation_ids_size(); - const int64 fusion_id = proto.called_computation_ids(0); - auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); - TF_RET_CHECK(fused_computation != nullptr) - << "No fusion computation with id " << fusion_id; - fused_computation->SetFusionInstruction(instruction.get()); - instruction->called_computations_.push_back(fused_computation); - } else { - for (const int64 computation_id : proto.called_computation_ids()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_id)) - << "No computation with id " << computation_id; - instruction->called_computations_.push_back( - computation_map.at(computation_id)); + std::unique_ptr instruction; + const auto operands = [&instruction_map, &proto](int index) { + return instruction_map.at(proto.operand_ids(index)); + }; + const auto computations = [&computation_map, &proto](int index) { + return computation_map.at(proto.called_computation_ids(index)); + }; + switch (opcode) { + // Ops migrated to subclasses. + case HloOpcode::kBatchNormTraining: + CHECK_EQ(proto.operand_ids_size(), 3); + instruction = CreateBatchNormTraining( + proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), + proto.feature_index()); + break; + case HloOpcode::kBatchNormInference: + CHECK_EQ(proto.operand_ids_size(), 5); + instruction = CreateBatchNormInference( + proto.shape(), operands(0), operands(1), operands(2), operands(3), + operands(4), proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kBatchNormGrad: + CHECK_EQ(proto.operand_ids_size(), 5); + instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), + operands(2), operands(3), operands(4), + proto.epsilon(), proto.feature_index()); + break; + case HloOpcode::kFft: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector fft_length(proto.fft_length().begin(), + proto.fft_length().end()); + instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), + tensorflow::gtl::ArraySlice(fft_length)); + break; + } + case HloOpcode::kSend: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSend(operands(0), proto.channel_id()); + break; + case HloOpcode::kSendDone: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateSendDone(operands(0)); + break; + case HloOpcode::kRecv: + CHECK_EQ(proto.operand_ids_size(), 0); + instruction = + CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + break; + case HloOpcode::kRecvDone: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateRecvDone(operands(0)); + break; + case HloOpcode::kReverse: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateReverse(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kConcatenate: { + CHECK_EQ(proto.dimensions_size(), 1); + std::vector concat_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + concat_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateConcatenate(proto.shape(), concat_operands, + proto.dimensions(0)); + break; + } + case HloOpcode::kReduce: + CHECK_EQ(proto.operand_ids_size(), 2); + CHECK_EQ(proto.called_computation_ids_size(), 1); + instruction = CreateReduce(proto.shape(), operands(0), operands(1), + std::vector(proto.dimensions().begin(), + proto.dimensions().end()), + computations(0)); + break; + case HloOpcode::kTranspose: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateTranspose(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kBroadcast: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = + CreateBroadcast(proto.shape(), operands(0), + std::vector(proto.dimensions().begin(), + proto.dimensions().end())); + break; + case HloOpcode::kMap: { + CHECK_EQ(proto.called_computation_ids_size(), 1); + std::vector map_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + map_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateMap(proto.shape(), map_operands, computations(0)); + break; + } + case HloOpcode::kSlice: { + CHECK_EQ(proto.operand_ids_size(), 1); + std::vector slice_starts, slice_limits, slice_strides; + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + slice_starts.push_back(slice_dimensions.start()); + slice_limits.push_back(slice_dimensions.limit()); + slice_strides.push_back(slice_dimensions.stride()); + } + instruction = CreateSlice(proto.shape(), operands(0), slice_starts, + slice_limits, slice_strides); + break; + } + case HloOpcode::kConstant: { + CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateConstant(std::move(literal)); + break; + } + case HloOpcode::kTrace: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Trace instruction should have 1 operand but sees " + << proto.operand_ids_size(); + CHECK(proto.has_literal()); + TF_ASSIGN_OR_RETURN(auto literal, + Literal::CreateFromProto(proto.literal())); + instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); + break; + } + case HloOpcode::kFusion: { + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within + // the HloModuleProto. + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(FusionKind fusion_kind, + StringToFusionKind(proto.fusion_kind())); + + // Find the fused computation and set its fusion instruction. + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Expect 1 called computation for fusion instruction, but sees " + << proto.called_computation_ids_size(); + const int64 fusion_id = proto.called_computation_ids(0); + auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); + TF_RET_CHECK(fused_computation != nullptr) + << "No fusion computation with id " << fusion_id; + std::vector fusion_operands(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + fusion_operands.begin(), + [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateFusion(proto.shape(), fusion_kind, fusion_operands, + fused_computation); + break; + } + case HloOpcode::kRng: { + std::vector rng_parms(proto.operand_ids_size()); + std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), + rng_parms.begin(), [&instruction_map](int64 operand_id) { + return instruction_map.at(operand_id); + }); + instruction = CreateRng(proto.shape(), proto.distribution(), rng_parms); + break; + } + case HloOpcode::kParameter: + instruction = CreateParameter(proto.parameter_number(), proto.shape(), + proto.name()); + break; + case HloOpcode::kGetTupleElement: + CHECK_EQ(proto.operand_ids_size(), 1); + instruction = CreateGetTupleElement(proto.shape(), operands(0), + proto.tuple_index()); + break; + case HloOpcode::kReducePrecision: + instruction = + CreateReducePrecision(proto.shape(), operands(0), + proto.exponent_bits(), proto.mantissa_bits()); + break; + default: { + instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const int64 operand_id : proto.operand_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_id)) + << "No instruction with id " << operand_id; + instruction->AppendOperand(instruction_map.at(operand_id)); + } + for (const int64 predecessor_id : proto.control_predecessor_ids()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id)) + << "No instruction with id " << predecessor_id; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id) + ->AddControlDependencyTo(instruction.get())); + } + if (instruction->opcode() != HloOpcode::kFusion) { + for (const int64 computation_id : proto.called_computation_ids()) { + TF_RET_CHECK(ContainsKey(computation_map, computation_id)) + << "No computation with id " << computation_id; + instruction->called_computations_.push_back( + computation_map.at(computation_id)); + } + } + break; } - } - - if (instruction->opcode() == HloOpcode::kTrace) { - TF_RET_CHECK(instruction->operands().size() == 1) - << "Trace instruction should have 1 operand but sees " - << instruction->operands().size(); - instruction->mutable_operand(0)->set_tracing(instruction.get()); } TF_RET_CHECK(!proto.name().empty()); - instruction->name_ = proto.name(); - + instruction->SetAndSanitizeName(proto.name()); instruction->metadata_ = proto.metadata(); - instruction->set_backend_config(proto.backend_config()); - if (proto.has_literal()) { - TF_ASSIGN_OR_RETURN(instruction->literal_, - Literal::CreateFromProto(proto.literal())); - } - instruction->parameter_number_ = proto.parameter_number(); + instruction->backend_config_ = proto.backend_config(); - instruction->tuple_index_ = proto.tuple_index(); - for (int64 dimension : proto.dimensions()) { - instruction->dimensions_.push_back(dimension); - } if (proto.has_window()) { instruction->window_ = MakeUnique(proto.window()); } @@ -132,14 +285,7 @@ StatusOr> HloInstruction::CreateFromProto( instruction->dot_dimension_numbers_ = MakeUnique(proto.dot_dimension_numbers()); } - for (const HloInstructionProto::SliceDimensions& slice_dimensions : - proto.slice_dimensions()) { - instruction->slice_starts_.push_back(slice_dimensions.start()); - instruction->slice_limits_.push_back(slice_dimensions.limit()); - instruction->slice_strides_.push_back(slice_dimensions.stride()); - } - instruction->exponent_bits_ = proto.exponent_bits(); - instruction->mantissa_bits_ = proto.mantissa_bits(); + for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); } @@ -148,17 +294,9 @@ StatusOr> HloInstruction::CreateFromProto( MakeUnique(proto.padding_config()); } instruction->outfeed_config_ = proto.outfeed_config(); - instruction->distribution_ = proto.distribution(); - instruction->epsilon_ = proto.epsilon(); - instruction->feature_index_ = proto.feature_index(); - instruction->channel_id_ = proto.channel_id(); instruction->infeed_config_ = proto.infeed_config(); instruction->custom_call_target_ = proto.custom_call_target(); instruction->outfeed_shape_ = proto.outfeed_shape(); - instruction->fft_type_ = proto.fft_type(); - for (int64 fft_len : proto.fft_length()) { - instruction->fft_length_.push_back(fft_len); - } if (proto.has_sharding()) { TF_ASSIGN_OR_RETURN(const auto& sharding, @@ -177,57 +315,38 @@ StatusOr> HloInstruction::CreateFromProto( instruction->channel_name_ = proto.channel_name(); instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); + for (int64 replica_group_id : proto.replica_group_ids()) { + instruction->replica_group_ids_.push_back(replica_group_id); + } + return std::move(instruction); } /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); - instruction->parameter_number_ = parameter_number; - instruction->name_ = name; - return instruction; + return MakeUnique(parameter_number, shape, name); } /* static */ std::unique_ptr HloInstruction::CreateTrace( const string& tag, HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); - instruction->operands_.push_back(operand); - instruction->literal_ = Literal::CreateR1U8(tag); - operand->set_tracing(instruction.get()); - return instruction; + return MakeUnique(tag, operand); } /* static */ std::unique_ptr HloInstruction::CreateConstant( std::unique_ptr literal) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape())); - instruction->literal_ = std::move(literal); - return instruction; + return MakeUnique(std::move(literal)); } /* static */ std::unique_ptr HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { - CHECK(ShapeUtil::IsTuple(operand->shape())); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape)); - instruction->tuple_index_ = index; - instruction->AppendOperand(operand); - return instruction; + return MakeUnique(shape, operand, index); } /* static */ std::unique_ptr HloInstruction::CreateRng( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape)); - instruction->distribution_ = distribution; - instruction->shape_ = shape; - for (HloInstruction* param : parameters) { - instruction->AppendOperand(param); - } - return instruction; + return MakeUnique(shape, distribution, parameters); } /* static */ std::unique_ptr HloInstruction::CreateNary( @@ -256,6 +375,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kClz: + case HloOpcode::kDomain: case HloOpcode::kExp: case HloOpcode::kExpm1: case HloOpcode::kFloor: @@ -341,13 +461,8 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, const Shape& shape, tensorflow::gtl::ArraySlice operands, HloComputation* map_computation, tensorflow::gtl::ArraySlice static_operands) { - CHECK(static_operands.empty()) << "static_operands not yet supported"; - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->called_computations_.push_back(map_computation); - return instruction; + return MakeUnique(shape, operands, map_computation, + static_operands); } /* static */ std::unique_ptr HloInstruction::CreateConvolve( @@ -373,11 +488,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateFft( const Shape& shape, HloInstruction* operand, FftType fft_type, tensorflow::gtl::ArraySlice fft_length) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); - instruction->AppendOperand(operand); - instruction->fft_type_ = fft_type; - instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); - return instruction; + return MakeUnique(shape, operand, fft_type, fft_length); } /* static */ std::unique_ptr HloInstruction::CreateDot( @@ -410,18 +521,29 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); - instruction->AppendOperand(operand); - instruction->exponent_bits_ = exponent_bits; - instruction->mantissa_bits_ = mantissa_bits; - return instruction; + return MakeUnique( + shape, operand, exponent_bits, mantissa_bits); } /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { - return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& channel_id) { + // TODO(b/79737069): Remove the CHECK when supported. + CHECK(!channel_id.has_value()); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->called_computations_.push_back(reduce_computation); + instruction->replica_group_ids_.assign(replica_group_ids.begin(), + replica_group_ids.end()); + instruction->cross_replica_sum_barrier_ = std::string(barrier); + return instruction; } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -447,56 +569,44 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { - // Send instruction produces a tuple of {aliased operand, U32 context}. - Shape output_shape = ShapeUtil::MakeTupleShape( - {operand->shape(), ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(operand, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateSendDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kSend) + auto send_operand = DynCast(operand); + CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - auto instruction = WrapUnique( - new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(send_operand); } /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - // Recv instruction produces a tuple of {receive buffer, U32 context}. - Shape output_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); - instruction->channel_id_ = channel_id; - return instruction; + return MakeUnique(shape, channel_id); } /* static */ std::unique_ptr HloInstruction::CreateRecvDone( HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kRecv) + auto recv_operand = DynCast(operand); + CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; + return MakeUnique(recv_operand); } /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); + return MakeUnique(shape, operand, dimensions); +} + +/* static */ std::unique_ptr +HloInstruction::CreateGenerateToken( + tensorflow::gtl::ArraySlice operands) { + auto instruction = WrapUnique(new HloInstruction( + HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } return instruction; } @@ -533,18 +643,8 @@ HloInstruction::CreateCrossReplicaSum( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape)); - instruction->AppendOperand(operand); - instruction->slice_starts_.assign(start_indices.begin(), start_indices.end()); - instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end()); - instruction->slice_strides_.assign(strides.begin(), strides.end()); - // For backward compatibility with old serialized computations: if there are - // no strides, assume all strides are 1. - // TODO(b/63317920): remove this code. - if (instruction->slice_strides_.empty()) { - instruction->slice_strides_ = std::vector(start_indices.size(), 1LL); - } - return instruction; + return MakeUnique(shape, operand, start_indices, + limit_indices, strides); } /* static */ std::unique_ptr HloInstruction::CreateDynamicSlice( @@ -575,13 +675,7 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateConcatenate( const Shape& shape, tensorflow::gtl::ArraySlice operands, int64 dimension) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); - } - instruction->dimensions_.push_back(dimension); - return instruction; + return MakeUnique(shape, operands, dimension); } /* static */ std::unique_ptr HloInstruction::CreateConvert( @@ -604,13 +698,8 @@ HloInstruction::CreateBitcastConvert(const Shape& shape, const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, HloComputation* reduce_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape)); - instruction->AppendOperand(arg); - instruction->AppendOperand(init_value); - instruction->dimensions_.assign(dimensions_to_reduce.begin(), - dimensions_to_reduce.end()); - instruction->called_computations_.push_back(reduce_computation); - return instruction; + return MakeUnique( + shape, arg, init_value, dimensions_to_reduce, reduce_computation); } /* static */ std::unique_ptr HloInstruction::CreateReduceWindow( @@ -631,14 +720,8 @@ HloInstruction::CreateBatchNormTraining(const Shape& shape, HloInstruction* scale, HloInstruction* offset, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, epsilon, feature_index); } /* static */ std::unique_ptr @@ -646,16 +729,8 @@ HloInstruction::CreateBatchNormInference( const Shape& shape, HloInstruction* operand, HloInstruction* scale, HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormInference, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(offset); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique( + shape, operand, scale, offset, mean, variance, epsilon, feature_index); } /* static */ std::unique_ptr @@ -664,16 +739,9 @@ HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand, HloInstruction* variance, HloInstruction* grad_output, float epsilon, int64 feature_index) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBatchNormGrad, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(scale); - instruction->AppendOperand(mean); - instruction->AppendOperand(variance); - instruction->AppendOperand(grad_output); - instruction->epsilon_ = epsilon; - instruction->feature_index_ = feature_index; - return instruction; + return MakeUnique(shape, operand, scale, mean, + variance, grad_output, epsilon, + feature_index); } /* static */ std::unique_ptr @@ -696,12 +764,8 @@ HloInstruction::CreateSelectAndScatter( /* static */ std::unique_ptr HloInstruction::CreateBroadcast( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(broadcast_dimensions.begin(), - broadcast_dimensions.end()); - return instruction; + return MakeUnique(shape, operand, + broadcast_dimensions); } /* static */ std::unique_ptr @@ -780,45 +844,29 @@ HloInstruction::CreateBroadcastSequence( /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { - CHECK_EQ(shape.dimensions().size(), dimensions.size()); - CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); - CHECK(std::equal(operand->shape().dimensions().begin(), - operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())) - << "shape: " << ShapeUtil::HumanString(shape) - << ", operand->shape(): " << ShapeUtil::HumanString(shape) - << ", dimensions: {" << Join(dimensions, ", ") << "}"; - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); - instruction->AppendOperand(operand); - instruction->dimensions_.assign(dimensions.begin(), dimensions.end()); - return instruction; + return MakeUnique(shape, operand, dimensions); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->set_parent(fused_root->parent()); - instruction->set_metadata(fused_root->metadata()); - instruction->CloneAndFuseInternal(fused_root); - return instruction; + return MakeUnique(shape, fusion_kind, fused_root); } /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); - for (auto operand : operands) { - instruction->AppendOperand(operand); + return MakeUnique(shape, fusion_kind, operands, + fusion_computation); +} + +void HloInstruction::set_single_sharding(const HloSharding& sharding) { + CHECK(!sharding.IsTuple()) << sharding; + if (ShapeUtil::IsTuple(shape())) { + set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape()))); + } else { + set_sharding(sharding); } - instruction->fusion_kind_ = fusion_kind; - instruction->name_ = "fusion"; - instruction->called_computations_.push_back(fusion_computation); - fusion_computation->SetFusionInstruction(instruction.get()); - return instruction; } void HloInstruction::SetupDerivedInstruction( @@ -831,289 +879,6 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->set_metadata(metadata_); } -HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { - CHECK_EQ(opcode(), HloOpcode::kFusion); - CHECK_EQ(operand_count(), - fused_instructions_computation()->parameter_instructions().size()); - const int64 param_no = operand_count(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(new_operand->name(), ".param_", param_no); - HloInstruction* fused_parameter = - fused_instructions_computation()->AddParameter( - HloInstruction::CreateParameter(param_no, new_operand->shape(), - param_name)); - AppendOperand(new_operand); - return fused_parameter; -} - -void HloInstruction::MergeFusionInstruction( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != - operands().end()); - // Clone the instruction from which to merge fused instructions. - std::unique_ptr clone = instruction_to_merge->Clone(); - // Replace uses of fused parameters with the corresponding operand of the - // fusion. Add all non-parameter fused instructions to 'unfused_instructions' - // to be merged into 'this'. This is done in reverse post order. - std::vector unfused_instructions; - auto fused_instructions = - clone->fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_it = fused_instructions.rbegin(); - fused_it != fused_instructions.rend(); ++fused_it) { - auto fused_instruction = *fused_it; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - TF_CHECK_OK(fused_instruction->ReplaceAllUsesWith( - clone->mutable_operand(fused_instruction->parameter_number()))); - } else { - unfused_instructions.push_back(fused_instruction); - } - } - CHECK(unfused_instructions.front() == clone->fused_expression_root()); - // Replace instruction_to_merge use of 'this' with unfused_root. - TF_CHECK_OK( - instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); - // Fuse 'unfused_instructions' into 'this'. - for (auto& instruction : unfused_instructions) { - FuseInstruction(instruction); - instruction->DetachFromOperands(); - } - CHECK_EQ(0, clone->user_count()); - clone->DetachFromOperands(); - TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( - clone->fused_instructions_computation())); -} - -void HloInstruction::MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK_EQ(instruction_to_merge->opcode(), HloOpcode::kFusion); - // Add all non-parameter fused instructions to 'unfused_instructions' to be - // merged into 'this'. `old_to_new' maps the instructions in the fused node - // to the disaseembled fusion instructions. - // Note that we add the unfused instructions to this->parent_ computation. - // This is necessary because the unique_id needs for an instruction and - // it's only added when inserting to the computation. - tensorflow::gtl::FlatMap old_to_new; - std::vector unfused_instructions; - auto computation_to_merge = - instruction_to_merge->fused_instructions_computation(); - auto post_order = computation_to_merge->MakeInstructionPostOrder(); - for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { - auto fused_instruction = *rit; - if (fused_instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&old_to_new, fused_instruction, - instruction_to_merge->mutable_operand( - fused_instruction->parameter_number())); - continue; - } - - // Here we clone the insertion and call FuseInstructionIntoMultiOutput() - // which clones again. This can be improved. - auto cloned_instruction = - parent_->AddInstruction(fused_instruction->Clone()); - unfused_instructions.push_back(cloned_instruction); - InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); - } - for (auto unfused_instruction : unfused_instructions) { - for (int64 index = 0; index < unfused_instruction->operand_count(); - index++) { - auto new_operand = - FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); - TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); - } - } - - HloInstruction* unfused_root = unfused_instructions.front(); - TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); - - TF_CHECK_OK( - instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); - if (GetModule()) { - TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); - } - - // Fuse the root instruction and generate multiple outputs. - FuseInstructionIntoMultiOutput(unfused_root); - TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); - // The rest instructions are of normal fusing. - for (int64 i = 1; i < unfused_instructions.size(); i++) { - auto instruction = unfused_instructions[i]; - FuseInstruction(instruction); - TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); - } -} - -HloInstruction* HloInstruction::FuseInstructionInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - - // When add_output is false, this fusion instruction must be a user of - // instruction_to_fuse. - if (!add_output) { - CHECK(IsUserOf(instruction_to_fuse)); - } - HloInstruction* fused_instruction = - CloneAndFuseInternal(instruction_to_fuse, add_output); - return fused_instruction; -} - -HloInstruction* HloInstruction::CloneAndFuseInternal( - HloInstruction* instruction_to_fuse, bool add_output) { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); - VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); - HloInstruction* clone = nullptr; - if (called_computations_.empty()) { - // New fusion instruction. It should not be a multioutput instruction. - CHECK(!add_output); - auto builder = HloComputation::Builder("fused_computation", this); - builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); - called_computations_.push_back( - CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); - clone = fused_expression_root(); - } else { - clone = fused_instructions_computation()->AddInstruction( - instruction_to_fuse->Clone(/*suffix=*/"")); - // When add_output is false, instruction_to_fuse is necessarily an operand - // of the fusion instruction. After fusion this will no longer be the case. - // Remove the operand from the operand list and remove its corresponding - // fused parameter instruction. Renumber parameters as necessary to make - // parameter numbers consistent with their index in the - // fused_parameter_ vector. - bool in_operand_list = std::find(operands_.begin(), operands_.end(), - instruction_to_fuse) != operands_.end(); - CHECK(add_output || in_operand_list); - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { - if (instruction_to_fuse == operands_[operand_num]) { - // replace the fused parameter instruction's uses with the clone. - HloInstruction* fused_parameter = fused_parameters[operand_num]; - TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); - - // Remove the corresponding fused parameter and operand from their - // respective vectors. - TF_CHECK_OK( - fused_instructions_computation()->RemoveParameter(operand_num)); - operands_.erase(operands_.begin() + operand_num); - break; - } - } - // We've cloned instruction_to_fuse into this fusion instruction, so this - // fusion instruction is no longer a use of instruction_to_fuse. - if (in_operand_list) { - instruction_to_fuse->RemoveUser(this); - // When the instruction_to_fuse does not have other users, we don't need - // to generate a multioutput fusion instruction. - if (instruction_to_fuse->user_count() == 0) { - add_output = false; - } - } - } - - // Reread the parameters in the computation. - const std::vector& fused_parameters = - fused_instructions_computation()->parameter_instructions(); - - // Add each operand of the clone as an operand of the fusion instruction. A - // complication is that some clone operands may already be operands of the - // fusion instruction. - for (int64 operand_num = 0; operand_num < clone->operand_count(); - ++operand_num) { - HloInstruction* operand = clone->mutable_operand(operand_num); - - // See if this operand is already an operand of the fusion node. - CHECK_EQ(operands_.size(), fused_parameters.size()); - HloInstruction* fused_param = nullptr; - for (int64 i = 0; i < operands_.size(); ++i) { - if (operands_[i] == operand) { - fused_param = fused_parameters[i]; - break; - } - } - - if (fused_param == nullptr) { - // Clone's operand was not already an operand of the fusion - // instruction. Add it as an operand and add a corresponding fused - // parameter instruction. - fused_param = AddFusionOperand(operand); - } - TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); - } - - if (add_output) { - CHECK_GT(instruction_to_fuse->user_count(), 0); - // If this is already a multioutput fusion instruction, expand the root - // tuple by 1. - HloInstruction* fused_root = fused_expression_root(); - HloInstruction::InstructionVector tuple_elements; - bool newly_created_tuple_instr = false; - if (fused_root->opcode() == HloOpcode::kTuple) { - tuple_elements = fused_root->operands(); - } else { - tuple_elements.push_back(fused_root); - newly_created_tuple_instr = true; - } - if (clone->opcode() == HloOpcode::kTuple) { - for (auto inst : clone->operands()) { - tuple_elements.push_back(inst); - } - } else { - tuple_elements.push_back(clone); - } - HloInstruction* new_root = fused_instructions_computation()->AddInstruction( - HloInstruction::CreateTuple(tuple_elements)); - fused_instructions_computation()->set_root_instruction(new_root); - shape_ = new_root->shape(); - if (fused_root->opcode() == HloOpcode::kTuple) { - TF_CHECK_OK( - fused_instructions_computation()->RemoveInstruction(fused_root)); - } - - // If this is a newly created multioutput instruction, we need to update - // the use of the original fusion instruction. - if (newly_created_tuple_instr) { - HloInstruction* new_instr = parent_->AddInstruction( - HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); - TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); - } - int64 index = tuple_elements.size(); - if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { - index -= instruction_to_fuse->operand_count(); - std::vector to_be_removed; - for (auto old_gte : instruction_to_fuse->users()) { - CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); - int64 old_tuple_index = old_gte->tuple_index(); - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - old_gte->shape(), this, index + old_tuple_index)); - TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); - to_be_removed.push_back(old_gte); - } - for (auto old_gte : to_be_removed) { - TF_CHECK_OK(parent_->RemoveInstruction(old_gte)); - } - TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); - } else { - HloInstruction* new_gte = - parent_->AddInstruction(HloInstruction::CreateGetTupleElement( - clone->shape(), this, index - 1)); - TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); - } - } - - VLOG(2) << "New clone:\n" << clone->ToString(); - return clone; -} - -RandomDistribution HloInstruction::random_distribution() const { - CHECK_EQ(opcode_, HloOpcode::kRng); - return distribution_; -} - bool HloInstruction::HasSideEffectNoRecurse() const { switch (opcode_) { case HloOpcode::kSend: @@ -1225,25 +990,58 @@ bool HloInstruction::HasSideEffect() const { return gather_dim_numbers; } +/* static */ std::unique_ptr HloInstruction::CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDomain, shape)); + instruction->operand_side_metadata_ = std::move(operand_side_metadata); + instruction->user_side_metadata_ = std::move(user_side_metadata); + instruction->AppendOperand(operand); + return instruction; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, - HloModule* module, CloneMap* clone_map) const { + HloCloneContext* context) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " %" << new_operand->name(); } - if (module == nullptr) { - module = GetModule(); - } std::unique_ptr clone; - // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. switch (opcode_) { + // Ops migrated to subclasses. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + clone = CloneWithNewOperandsImpl(shape, new_operands, context); + break; // Unary ops. case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: @@ -1302,23 +1100,24 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2]); break; // Other supported ops. - case HloOpcode::kBroadcast: - CHECK_EQ(new_operands.size(), 1); - clone = CreateBroadcast(shape, new_operands[0], dimensions_); - break; case HloOpcode::kCall: clone = CreateCall(shape, new_operands, to_apply()); break; case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); + if (window_ != nullptr) { + clone->window_ = MakeUnique(*window_); + } + if (convolution_dimension_numbers_ != nullptr) { + clone->convolution_dimension_numbers_ = + MakeUnique( + *convolution_dimension_numbers_); + } break; case HloOpcode::kHostCompute: clone = CreateHostCompute(shape, new_operands, channel_name_, cost_estimate_ns_); break; - case HloOpcode::kConcatenate: - clone = CreateConcatenate(shape, new_operands, dimensions(0)); - break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); @@ -1327,11 +1126,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateBitcastConvert(shape, new_operands[0]); break; - case HloOpcode::kReducePrecision: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); - break; case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, @@ -1342,30 +1136,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDot(shape, new_operands[0], new_operands[1], *dot_dimension_numbers_); break; - case HloOpcode::kFft: - CHECK_EQ(new_operands.size(), 1); - clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_); - break; case HloOpcode::kCrossReplicaSum: - clone = CreateCrossReplicaSum(shape, new_operands); - break; - case HloOpcode::kGetTupleElement: - CHECK_EQ(new_operands.size(), 1); - clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); - break; - case HloOpcode::kMap: - clone = CreateMap(shape, new_operands, to_apply()); + clone = + CreateCrossReplicaSum(shape, new_operands, to_apply(), + replica_group_ids_, cross_replica_sum_barrier_); break; case HloOpcode::kPad: CHECK_EQ(new_operands.size(), 2); clone = CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); break; - case HloOpcode::kReduce: - CHECK_EQ(new_operands.size(), 2); - clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); - break; case HloOpcode::kReduceWindow: CHECK_EQ(new_operands.size(), 2); clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], @@ -1377,22 +1157,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CreateSelectAndScatter(shape, new_operands[0], select(), *window_, new_operands[1], new_operands[2], scatter()); break; - case HloOpcode::kReverse: - CHECK_EQ(new_operands.size(), 1); - clone = CreateReverse(shape, new_operands[0], dimensions_); - break; - case HloOpcode::kRng: - clone = CreateRng(shape, distribution_, new_operands); - break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); clone = CreateReshape(shape, new_operands[0]); break; - case HloOpcode::kSlice: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); - break; case HloOpcode::kDynamicSlice: clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); @@ -1402,10 +1170,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], new_operands[2]); break; - case HloOpcode::kTranspose: - CHECK_EQ(new_operands.size(), 1); - clone = CreateTranspose(shape, new_operands[0], dimensions_); - break; case HloOpcode::kTuple: clone = CreateTuple(new_operands); *clone->mutable_shape() = shape; @@ -1415,33 +1179,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); break; - case HloOpcode::kConstant: - clone = CreateConstant(literal_->CloneToUnique()); - break; - case HloOpcode::kFusion: { - CHECK_NE(module, nullptr); - auto new_fused_computation = module->AddEmbeddedComputation( - fused_instructions_computation()->Clone("clone", module, clone_map)); - clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(), - /*operands=*/new_operands, - /*fusion_computation=*/new_fused_computation); - break; - } - case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, name_); - break; - case HloOpcode::kBatchNormTraining: - CHECK_EQ(new_operands.size(), 3); - clone = - CreateBatchNormTraining(shape, new_operands[0], new_operands[1], - new_operands[2], epsilon(), feature_index()); - break; - case HloOpcode::kBatchNormInference: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormInference( - shape, new_operands[0], new_operands[1], new_operands[2], - new_operands[3], new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kInfeed: CHECK_EQ(new_operands.size(), 0); clone = CreateInfeed(shape, infeed_config()); @@ -1450,49 +1187,37 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); break; - case HloOpcode::kBatchNormGrad: - CHECK_EQ(new_operands.size(), 5); - clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1], - new_operands[2], new_operands[3], - new_operands[4], epsilon(), feature_index()); - break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), 3); clone = CreateConditional(shape, new_operands[0], new_operands[1], true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kSend: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSend(new_operands[0], channel_id()); - break; - case HloOpcode::kSendDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateSendDone(new_operands[0]); - break; - case HloOpcode::kRecv: - CHECK_EQ(new_operands.size(), 0); - // The shape is a tuple, but CreateRecv() wants the raw data shape. - clone = - CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); - break; - case HloOpcode::kRecvDone: - CHECK_EQ(new_operands.size(), 1); - clone = CreateRecvDone(new_operands[0]); - break; case HloOpcode::kGather: CHECK_EQ(new_operands.size(), 2); clone = CreateGather(shape, new_operands[0], new_operands[1], *gather_dimension_numbers_, gather_window_bounds_); break; - case HloOpcode::kTrace: - LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); + case HloOpcode::kDomain: + CHECK_EQ(new_operands.size(), 1); + clone = + CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), + user_side_metadata_->Clone()); + break; + case HloOpcode::kGenerateToken: + clone = CreateGenerateToken(new_operands); + break; } SetupDerivedInstruction(clone.get()); clone->set_parent(parent_); - clone->set_backend_config(backend_config()); - if (clone_map != nullptr) { - InsertOrDie(clone_map, this, clone.get()); + clone->set_raw_backend_config_string(backend_config_); + if (context != nullptr) { + context->MapInstruction(this, clone.get()); + clone->ReplaceCalledComputations([&](HloComputation* callee) { + return callee->parent() != context->module() + ? context->module()->DeepCloneComputation(callee, context) + : callee; + }); } return clone; } @@ -1500,9 +1225,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( HloInstruction::~HloInstruction() {} std::unique_ptr HloInstruction::Clone( - const string& suffix, HloModule* module, CloneMap* clone_map) const { + const string& suffix, HloCloneContext* context) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_, module, clone_map); + CloneWithNewOperands(shape_, operands_, context); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1562,40 +1287,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const { return hlo; } -const Literal& HloInstruction::literal() const { - CHECK_EQ(HloOpcode::kConstant, opcode_); - return *literal_; -} - -bool HloInstruction::HasLiteral() const { return literal_ != nullptr; } - -bool HloInstruction::CanHaveDimensionsField() const { - return (opcode() == HloOpcode::kReverse || - opcode() == HloOpcode::kConcatenate || - opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast || - opcode() == HloOpcode::kTranspose); -} - -const std::vector& HloInstruction::dimensions() const { - CHECK(CanHaveDimensionsField()); - return dimensions_; -} - -int64 HloInstruction::dimensions(int64 index) const { - return dimensions()[index]; -} - -int64 HloInstruction::concatenate_dimension() const { - CHECK(opcode() == HloOpcode::kConcatenate); - CHECK_EQ(1, dimensions_.size()); - return dimensions(0); -} - -int64 HloInstruction::tuple_index() const { - CHECK_EQ(HloOpcode::kGetTupleElement, opcode_); - return tuple_index_; -} - const HloInstruction* HloInstruction::operand(int64 i) const { return operands_[i]; } @@ -1614,6 +1305,17 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +HloInstruction::InstructionVector HloInstruction::unique_operands() const { + InstructionVector unique; + tensorflow::gtl::FlatSet seen; + for (HloInstruction* operand : operands()) { + if (seen.insert(operand).second) { + unique.push_back(operand); + } + } + return unique; +} + Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { TF_RET_CHECK(instruction->parent() == parent()); if (std::find(control_successors_.begin(), control_successors_.end(), @@ -1673,10 +1375,6 @@ void HloInstruction::AddUser(HloInstruction* user) { } } -bool HloInstruction::IsConstant() const { - return opcode_ == HloOpcode::kConstant; -} - bool HloInstruction::HasConstantOperand() const { for (const HloInstruction* operand : operands_) { if (operand->IsConstant()) { @@ -1706,7 +1404,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kConvert: case HloOpcode::kCopy: case HloOpcode::kCos: - case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: @@ -1746,41 +1443,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; - // Broadcast, Concatenate, and Transpose need the same dimensions field. - case HloOpcode::kBroadcast: - case HloOpcode::kConcatenate: - case HloOpcode::kTranspose: - return dimensions() == other.dimensions(); - - case HloOpcode::kFusion: - return fusion_kind() == other.fusion_kind() && - eq_computations(fused_instructions_computation(), - other.fused_instructions_computation()); - // These opcodes have complex or special behavior so just return false. - case HloOpcode::kRng: - case HloOpcode::kTrace: + case HloOpcode::kDomain: case HloOpcode::kWhile: + case HloOpcode::kGenerateToken: return false; - case HloOpcode::kParameter: - return parameter_number() == other.parameter_number(); - - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kBatchNormGrad: - return feature_index() == other.feature_index() && - epsilon() == other.epsilon(); - - // A constant is defined by the value in the literal. - case HloOpcode::kConstant: - return literal() == other.literal(); - - // A reduce-precision operation is determined by the bit sizes. - case HloOpcode::kReducePrecision: - return exponent_bits() == other.exponent_bits() && - mantissa_bits() == other.mantissa_bits(); - // Convolution has a window and dimensions. case HloOpcode::kConvolution: return protobuf_util::ProtobufEquals(window(), other.window()) && @@ -1797,16 +1465,6 @@ bool HloInstruction::IdenticalSlowPath( other.gather_dimension_numbers()) && gather_window_bounds() == other.gather_window_bounds(); - // FFT has various types & lengths. - case HloOpcode::kFft: - return fft_type() == other.fft_type() && - fft_length() == other.fft_length(); - - // Reduction results are determined by the reduction dimension and the - // reduction computation. - case HloOpcode::kReduce: - return dimensions() == other.dimensions() && - eq_computations(to_apply(), other.to_apply()); case HloOpcode::kReduceWindow: return eq_computations(to_apply(), other.to_apply()) && protobuf_util::ProtobufEquals(window(), other.window()); @@ -1818,24 +1476,30 @@ bool HloInstruction::IdenticalSlowPath( eq_computations(scatter(), other.scatter()) && protobuf_util::ProtobufEquals(window(), other.window()); - // Remaining instructions with special values. - case HloOpcode::kGetTupleElement: - return tuple_index() == other.tuple_index(); case HloOpcode::kPad: return protobuf_util::ProtobufEquals(padding_config(), other.padding_config()); - case HloOpcode::kSlice: - return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_ && - slice_strides_ == other.slice_strides_; case HloOpcode::kCall: - case HloOpcode::kMap: - return eq_computations(to_apply(), other.to_apply()); + case HloOpcode::kCrossReplicaSum: + return replica_group_ids() == other.replica_group_ids() && + cross_replica_sum_barrier() == other.cross_replica_sum_barrier() && + eq_computations(to_apply(), other.to_apply()); case HloOpcode::kCustomCall: + if ((window_ == nullptr) != (other.window_ == nullptr) || + (window_ != nullptr && + !protobuf_util::ProtobufEquals(window(), other.window()))) { + return false; + } + if ((convolution_dimension_numbers_ == nullptr) != + (other.convolution_dimension_numbers_ == nullptr) || + (convolution_dimension_numbers_ != nullptr && + !protobuf_util::ProtobufEquals( + convolution_dimension_numbers(), + other.convolution_dimension_numbers()))) { + return false; + } return custom_call_target_ == other.custom_call_target_; - case HloOpcode::kReverse: - return dimensions() == other.dimensions(); case HloOpcode::kConditional: return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); @@ -1844,21 +1508,36 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kHostCompute: return false; - } -} -bool HloInstruction::IsRank2Transpose() const { - return (opcode_ == HloOpcode::kTranspose) && - dimensions_ == std::vector({1, 0}) && - shape_.dimensions_size() == 2 && - std::equal(shape_.dimensions().begin(), shape_.dimensions().end(), - operands_[0]->shape_.dimensions().rbegin()); + // Ops migrated to subclasses should never come to this line. + // TODO(b/80131774): Remove this switch when migration is complete. + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kFft: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReverse: + case HloOpcode::kConcatenate: + case HloOpcode::kReduce: + case HloOpcode::kTranspose: + case HloOpcode::kBroadcast: + case HloOpcode::kMap: + case HloOpcode::kSlice: + case HloOpcode::kConstant: + case HloOpcode::kTrace: + case HloOpcode::kFusion: + case HloOpcode::kRng: + case HloOpcode::kParameter: + case HloOpcode::kGetTupleElement: + case HloOpcode::kReducePrecision: + LOG(FATAL) << "Base class impl called for opcode with subclass: " + << opcode(); + } } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1964,6 +1643,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kMap: case HloOpcode::kReduceWindow: case HloOpcode::kReduce: + case HloOpcode::kCrossReplicaSum: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -2099,6 +1779,71 @@ string HloInstruction::ToString(const HloPrintOptions& options) const { return ToStringWithCanonicalNameMap(options, &new_map); } +bool HloInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + switch (opcode_) { + // Unary elementwise operations. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kCeil: + case HloOpcode::kClz: + case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kFloor: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kReal: + case HloOpcode::kReducePrecision: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); + return true; + + // Binary elementwise operations, the same as in IsElementwiseBinary(). + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: + case HloOpcode::kDivide: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); + return true; + + // Ternary elementwise operations. + case HloOpcode::kSelect: + return !ShapeUtil::IsTuple(shape_); + case HloOpcode::kClamp: + return true; + + default: + return false; + } +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -2134,8 +1879,8 @@ string HloInstruction::ToStringWithCanonicalNameMap( !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } - if (options.print_backend_config() && !backend_config().empty()) { - StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\""); + if (options.print_backend_config() && !backend_config_.empty()) { + StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\""); } return result; } @@ -2149,76 +1894,39 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { string operands; - if (opcode() == HloOpcode::kConstant) { - // For constants, show the actual value in place of an empty operand list. - // - // In HloInstruction, sometimes a constant literal is not constructed due - // to its size. Skip the printing in this case. - if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) && - ShapeUtil::ElementsIn(shape()) <= 10) || - options.print_large_constants())) { - // Literal::ToString emits multidimensional arrays over multiple - // lines. Compact this into one line by stripping out white space. - string tmp = literal().ToString(); - std::replace(tmp.begin(), tmp.end(), '\n', ' '); - std::vector v = tensorflow::str_util::Split(tmp, ' '); - bool first = true; - // Concatenate elements in "v" with spaces separating them, but ignoring - // empty entries. - for (const auto& s : v) { - if (s.empty()) { - continue; - } - StrAppend(&operands, (first ? "" : " "), s); - first = false; - } - } else { - // Do not show large constants or tuples. - operands = "{...}"; - } - } else if (opcode() == HloOpcode::kParameter) { - StrAppend(&operands, parameter_number_); - } else { - tensorflow::gtl::ArraySlice slice(operands_); - const int64 kMaxOperandsToShowIfCompact = 4; - if (options.compact_operands() && - slice.size() > kMaxOperandsToShowIfCompact) { - slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + tensorflow::gtl::ArraySlice slice(operands_); + const int64 kMaxOperandsToShowIfCompact = 4; + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { + slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); + } + operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); } - operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - std::vector str; - if (options.print_operand_shape()) { - str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); - } - // In a top-level HloInstruction::ToString() call, the operand name is not - // part of the canonical string. - if (options.canonicalize_instruction_names() && - options.is_in_nested_computation()) { - str.push_back(PrintName( - canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { - str.push_back(PrintName(operand->name(), options)); - } - StrAppend(out, Join(str, " ")); - }); - const int64 remaining = operands_.size() - slice.size(); - if (slice.size() != operands_.size()) { - StrAppend(&operands, ", ...(+", remaining, ")"); + // In a top-level HloInstruction::ToString() call, the operand name is not + // part of the canonical string. + if (options.canonicalize_instruction_names() && + options.is_in_nested_computation()) { + str.push_back(PrintName( + canonical_name_map->LookupOrInsert(operand->name()), options)); + } else if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); + }); + const int64 remaining = operands_.size() - slice.size(); + if (slice.size() != operands_.size()) { + StrAppend(&operands, ", ...(+", remaining, ")"); } return operands; } std::vector HloInstruction::ExtraAttributesToString( const HloPrintOptions& options) const { - std::vector extra; - if (opcode() == HloOpcode::kFusion) { - extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); - } - if (CanHaveDimensionsField()) { - extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); - } + std::vector extra = ExtraAttributesToStringImpl(options); if (window_ != nullptr && window_->dimensions_size() != 0) { extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } @@ -2226,32 +1934,16 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); } - if (opcode() == HloOpcode::kSlice) { - std::vector bounds; - bounds.reserve(slice_starts_.size()); - const bool omit_stride = - std::all_of(slice_strides_.begin(), slice_strides_.end(), - [](int64 stride) { return stride == 1; }); - for (int i = 0; i < slice_starts_.size(); ++i) { - string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); - bounds.push_back(StrCat("[", slice_starts_[i], ":", slice_limits_[i], - stride_str, "]")); - } - extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); - } + if (opcode() == HloOpcode::kDynamicSlice) { extra.push_back( StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); } - if (opcode() == HloOpcode::kBatchNormTraining || - opcode() == HloOpcode::kBatchNormInference || - opcode() == HloOpcode::kBatchNormGrad) { - extra.push_back(StrCat("epsilon=", epsilon())); - extra.push_back(StrCat("feature_index=", feature_index())); - } if (convolution_dimension_numbers_ != nullptr) { - extra.push_back(ConvolutionDimensionNumbersToString()); + extra.push_back(StrCat( + "dim_labels=", + ConvolutionDimensionNumbersToString(*convolution_dimension_numbers_))); } if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); @@ -2261,10 +1953,6 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); } - if (opcode() == HloOpcode::kFft) { - extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); - extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -2284,7 +1972,8 @@ std::vector HloInstruction::ExtraAttributesToString( PrintName(false_computation()->name(), options))); } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { + opcode() == HloOpcode::kReduce || + opcode() == HloOpcode::kCrossReplicaSum) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2334,14 +2023,7 @@ std::vector HloInstruction::ExtraAttributesToString( break; } } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || - opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { - extra.push_back(StrCat("channel_id=", channel_id_)); - } - if (opcode() == HloOpcode::kGetTupleElement) { - extra.push_back(StrCat("index=", tuple_index())); - } if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } @@ -2361,13 +2043,17 @@ std::vector HloInstruction::ExtraAttributesToString( extra.push_back( StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); } - if (opcode() == HloOpcode::kRng) { + if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { + extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), + "\", entry=", operand_side_metadata_->ToString(), + ", exit=", user_side_metadata_->ToString(), "}")); + } + if (!replica_group_ids().empty()) { extra.push_back( - StrCat("distribution=", RandomDistributionToString(distribution_))); + StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")); } - if (opcode() == HloOpcode::kReducePrecision) { - extra.push_back(StrCat("exponent_bits=", exponent_bits_)); - extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); + if (!cross_replica_sum_barrier().empty()) { + extra.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } // By contract, we print the custom call target even if @@ -2407,25 +2093,13 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_metadata() = metadata_; - proto.set_backend_config(backend_config()); - if (literal_ != nullptr) { - *proto.mutable_literal() = literal_->ToProto(); - } - proto.set_parameter_number(parameter_number_); - if (opcode() == HloOpcode::kFusion) { - proto.set_fusion_kind(xla::ToString(fusion_kind())); - proto.add_called_computation_ids( - fused_instructions_computation()->unique_id()); - } else { + proto.set_backend_config(backend_config_); + if (opcode() != HloOpcode::kFusion) { for (const HloComputation* computation : called_computations_) { proto.add_called_computation_ids(computation->unique_id()); } } - proto.set_tuple_index(tuple_index_); - for (int64 dimension : dimensions_) { - proto.add_dimensions(dimension); - } if (window_ != nullptr) { *proto.mutable_window() = *window_; } @@ -2444,14 +2118,7 @@ HloInstructionProto HloInstruction::ToProto() const { proto.add_gather_window_bounds(bound); } } - for (int i = 0; i < slice_starts_.size(); ++i) { - auto* slice_dimension = proto.add_slice_dimensions(); - slice_dimension->set_start(slice_starts_[i]); - slice_dimension->set_limit(slice_limits_[i]); - slice_dimension->set_stride(slice_strides_[i]); - } - proto.set_exponent_bits(exponent_bits_); - proto.set_mantissa_bits(mantissa_bits_); + for (int64 slice_size : dynamic_slice_sizes_) { proto.add_dynamic_slice_sizes(slice_size); } @@ -2459,19 +2126,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_padding_config() = *padding_config_; } proto.set_outfeed_config(outfeed_config_); - if (opcode() == HloOpcode::kRng) { - proto.set_distribution(distribution_); - } - proto.set_epsilon(epsilon_); - proto.set_feature_index(feature_index_); - proto.set_channel_id(channel_id_); proto.set_infeed_config(infeed_config_); proto.set_custom_call_target(custom_call_target_); *proto.mutable_outfeed_shape() = outfeed_shape_; - proto.set_fft_type(fft_type_); - for (int64 fft_len : fft_length_) { - proto.add_fft_length(fft_len); - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); @@ -2479,6 +2136,9 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_channel_name(channel_name_); proto.set_cost_estimate_ns(cost_estimate_ns_); + for (int64 replica_group_id : replica_group_ids_) { + proto.add_replica_group_ids(replica_group_id); + } return proto; } @@ -2531,12 +2191,6 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { trace_instruction_ = trace_instruction; } -string HloInstruction::TracingTag() const { - CHECK_EQ(HloOpcode::kTrace, opcode()); - CHECK(literal_ != nullptr); - return literal_->GetR1U8AsString(); -} - bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } bool HloInstruction::IsFusable() const { @@ -2546,6 +2200,7 @@ bool HloInstruction::IsFusable() const { } // Some kinds of instructions don't make sense to fuse. switch (opcode_) { + case HloOpcode::kDomain: case HloOpcode::kParameter: return false; // Side effecting instrutions cannot be fused. @@ -2554,49 +2209,6 @@ bool HloInstruction::IsFusable() const { } } -HloComputation* HloInstruction::fused_instructions_computation() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(!called_computations_.empty()); - auto* fused_instructions_computation = called_computations_.front(); - CHECK(fused_instructions_computation->IsFusionComputation()); - return fused_instructions_computation; -} - -HloInstruction* HloInstruction::fused_expression_root() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->root_instruction(); -} - -HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instruction( - parameter_number); -} - -const std::vector& HloInstruction::fused_parameters() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->parameter_instructions(); -} - -const tensorflow::gtl::iterator_range>::const_iterator>> -HloInstruction::fused_instructions() const { - CHECK_EQ(opcode_, HloOpcode::kFusion); - const HloComputation* subcomp = fused_instructions_computation(); - return subcomp->instructions(); -} - -const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> -HloInstruction::fused_instructions() { - CHECK_EQ(opcode_, HloOpcode::kFusion); - return fused_instructions_computation()->instructions(); -} - -int64 HloInstruction::fused_instruction_count() const { - return fused_instructions_computation()->instruction_count(); -} - HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), @@ -2773,6 +2385,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSendDone(this); case HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kDomain: + return visitor->HandleDomain(this); + case HloOpcode::kGenerateToken: + return visitor->HandleGenerateToken(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3040,87 +2656,7 @@ bool HloInstruction::IsElementwiseBinary() const { } bool HloInstruction::IsElementwise() const { - switch (opcode_) { - // Nullary elementwise operations. - case HloOpcode::kConstant: - return true; - - // Unary elementwise operations. - case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: - case HloOpcode::kCeil: - case HloOpcode::kClz: - case HloOpcode::kConvert: - case HloOpcode::kBitcastConvert: - case HloOpcode::kCopy: - case HloOpcode::kCos: - case HloOpcode::kExp: - case HloOpcode::kExpm1: - case HloOpcode::kFloor: - case HloOpcode::kImag: - case HloOpcode::kIsFinite: - case HloOpcode::kLog: - case HloOpcode::kLog1p: - case HloOpcode::kNot: - case HloOpcode::kNegate: - case HloOpcode::kReal: - case HloOpcode::kReducePrecision: - case HloOpcode::kSign: - case HloOpcode::kSin: - case HloOpcode::kTanh: - CHECK_EQ(1, operand_count()); - return true; - - // Binary elementwise operations, the same as in IsElementwiseBinary(). - case HloOpcode::kAdd: - case HloOpcode::kAtan2: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - CHECK_EQ(2, operand_count()); - return true; - - // Ternary elementwise operations. - case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); - case HloOpcode::kClamp: - return true; - - // Other operations. - case HloOpcode::kRng: - case HloOpcode::kMap: - return true; - case HloOpcode::kFusion: - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - for (auto* fused : fused_instructions()) { - if (fused->opcode() != HloOpcode::kParameter && - !fused->IsElementwise()) { - return false; - } - } - return true; - - default: - return false; - } + return IsElementwiseImpl(tensorflow::gtl::nullopt); } bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { @@ -3128,54 +2664,8 @@ bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { return !ShapeUtil::SameDimensions(shape(), operand(operand_idx)->shape()); } -namespace { -bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, - const HloInstruction* operand) { - std::vector operand_indices = instruction->OperandIndices(operand); - return std::all_of( - operand_indices.begin(), operand_indices.end(), - [instruction](int64 operand_index) { - return instruction->IsElementwiseOnOperand(operand_index); - }); -} -} // namespace - bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const { - // For all instructions other than kFusion, being elementwise on one of the - // operands is equivalent to being elementwise on all the operands. - if (opcode() != HloOpcode::kFusion) { - return IsElementwise(); - } - - CHECK_EQ(HloOpcode::kFusion, opcode()); - if (fusion_kind() != FusionKind::kLoop) { - return false; - } - - // A loop-fusion is elementwise on an operand if all operations (computed - // using BFS) between the operand and the fused root are elementwise. - std::deque worklist; - std::unordered_set visited; - worklist.push_back(fused_parameter(operand_idx)); - visited.insert(fused_parameter(operand_idx)); - while (!worklist.empty()) { - HloInstruction* operand = worklist.front(); - worklist.pop_front(); - for (HloInstruction* user : operand->users()) { - CHECK_GE(user->unique_id(), 0); - if (ContainsKey(visited, user)) { - continue; - } - if (user->IsElementwise() || - IsInstructionElementwiseOnOperand(user, operand)) { - worklist.push_back(user); - visited.insert(user); - } else { - return false; - } - } - } - return true; + return IsElementwiseImpl(operand_idx); } // A helper class for memoized, recursive computation of HloOpcode::kFusion @@ -3197,8 +2687,10 @@ class HloInstruction::FusionReusesParamElements { static UseKind ComputeInternal( int64 i, const HloInstruction& hlo, tensorflow::gtl::FlatMap* cache) { - if (hlo.opcode_ == HloOpcode::kParameter && hlo.parameter_number_ == i) { - return UseKind::kUse; + if (auto hlo_param = DynCast(&hlo)) { + if (hlo_param->parameter_number() == i) { + return UseKind::kUse; + } } auto p = cache->emplace(&hlo, UseKind{}); @@ -3360,42 +2852,8 @@ string RandomDistributionToString(const RandomDistribution& distribution) { return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); } -StatusOr StringToRandomDistribution(const string& name) { - static std::unordered_map* map = [] { - static auto* map = new std::unordered_map; - for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { - if (RandomDistribution_IsValid(i)) { - auto value = static_cast(i); - (*map)[RandomDistributionToString(value)] = value; - } - } - return map; - }(); - auto found = map->find(tensorflow::str_util::Lowercase(name)); - if (found == map->end()) { - return InvalidArgument("Unknown distribution"); - } - return found->second; -} - -std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { - return os << ToString(kind); -} - -string HloInstruction::ConvolutionDimensionNumbersToString() const { - string result; - if (convolution_dimension_numbers_ == nullptr) { - return result; - } - const ConvolutionDimensionNumbers& dnums = *convolution_dimension_numbers_; - // Show the given dimension labels in order of major to minor based on the - // shape's layout. - const auto append_dims = [&](const std::vector& dims, - const Shape& shape) { - CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); - }; - +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums) { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); @@ -3419,19 +2877,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } - result += "dim_labels="; - append_dims(lhs_dims, operand(0)->shape()); - result += "_"; - append_dims(rhs_dims, operand(1)->shape()); - result += "->"; - - // A convolution can be represented as a kConvolution HLO or as a CustomCall - // that returns a tuple, the first element of which is the result of the - // convolution. - Shape this_shape = - ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); - append_dims(output_dims, this_shape); - return result; + return StrCat(Join(lhs_dims, ""), "_", Join(rhs_dims, ""), "->", + Join(output_dims, "")); } string HloInstruction::DotDimensionNumbersToString() const { @@ -3457,6 +2904,28 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + +std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { + return os << ToString(kind); +} + string HloInstruction::GatherDimensionNumbersToString() const { CHECK_NE(gather_dimension_numbers_.get(), nullptr); string output_window_dims = @@ -3488,6 +2957,31 @@ bool HloInstruction::CouldBeBitcast() const { } } +Status HloInstruction::GetBackendConfigInternal( + tensorflow::protobuf::Message* proto) const { + proto->Clear(); + + // Empty string does not parse as valid JSON, but it's a valid backend config, + // corresponding to the empty proto. + if (backend_config_.empty()) { + return Status::OK(); + } + return tensorflow::HumanReadableJsonToProto(backend_config_, proto); +} + +Status HloInstruction::set_backend_config( + const tensorflow::protobuf::Message& proto) { + TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto)); + return Status::OK(); +} + +/* static */ StatusOr HloInstruction::BackendConfigToRawString( + const tensorflow::protobuf::Message& proto) { + string ret; + TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret)); + return ret; +} + HloModule* HloInstruction::GetModule() const { if (parent_) { return parent_->parent(); @@ -3505,21 +2999,173 @@ void HloInstruction::set_outer_dimension_partitions( outer_dimension_partitions_ = outer_dimension_partitions; } +// TODO(b/80131774): Remove these temporary methods after transition. +int64 HloInstruction::feature_index() const { + return Cast(this)->feature_index(); +} + +float HloInstruction::epsilon() const { + return Cast(this)->epsilon(); +} + +FftType HloInstruction::fft_type() const { + return Cast(this)->fft_type(); +} + +const std::vector& HloInstruction::fft_length() const { + return Cast(this)->fft_length(); +} + +int64 HloInstruction::channel_id() const { + return Cast(this)->channel_id(); +} + +int64 HloInstruction::concatenate_dimension() const { + return Cast(this)->concatenate_dimension(); +} + +bool HloInstruction::IsRank2Transpose() const { + auto transpose = DynCast(this); + return transpose != nullptr && transpose->IsRank2Transpose(); +} + +int64 HloInstruction::slice_starts(int64 dimension) const { + return Cast(this)->slice_starts(dimension); +} + +const std::vector& HloInstruction::slice_starts() const { + return Cast(this)->slice_starts(); +} + +int64 HloInstruction::slice_limits(int64 dimension) const { + return Cast(this)->slice_limits(dimension); +} + +const std::vector& HloInstruction::slice_limits() const { + return Cast(this)->slice_limits(); +} + +int64 HloInstruction::slice_strides(int64 dimension) const { + return Cast(this)->slice_strides(dimension); +} + +const std::vector& HloInstruction::slice_strides() const { + return Cast(this)->slice_strides(); +} + +bool HloInstruction::IsInPlaceSlice() const { + return Cast(this)->IsInPlaceSlice(); +} + +const Literal& HloInstruction::literal() const { + return Cast(this)->literal(); +} + +bool HloInstruction::IsConstant() const { + return DynCast(this) != nullptr; +} + void HloInstruction::RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index) { - CHECK_EQ(opcode(), HloOpcode::kConstant); - Shape* mutable_array_subshape = - ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); - CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + Cast(this)->RelayoutConstant(new_layout, shape_index); +} - // Normally array_subshape will always have a layout, but this invariant is - // temporarily broken in LayoutAssignment::AssignLayouts. +string HloInstruction::TracingTag() const { + return Cast(this)->TracingTag(); +} - if (!mutable_array_subshape->has_layout() || - !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { - literal_ = literal_->Relayout(new_layout, shape_index); - *mutable_array_subshape->mutable_layout() = new_layout; - } +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + return Cast(this)->AddFusionOperand(new_operand); +} + +// Delegates to HloFusionInstruction::MergeFusionInstruction. +void HloInstruction::MergeFusionInstruction( + HloInstruction* instruction_to_merge) { + return Cast(this)->MergeFusionInstruction( + Cast(instruction_to_merge)); +} + +// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. +void HloInstruction::MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge) { + return Cast(this) + ->MergeFusionInstructionIntoMultiOutput( + Cast(instruction_to_merge)); +} + +HloInstruction* HloInstruction::FuseInstruction( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstruction(instruction_to_fuse); +} + +HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return Cast(this)->FuseInstructionIntoMultiOutput( + instruction_to_fuse); +} + +HloComputation* HloInstruction::fused_instructions_computation() const { + return Cast(this)->fused_instructions_computation(); +} + +HloInstruction* HloInstruction::fused_expression_root() const { + return Cast(this)->fused_expression_root(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloInstruction::fused_instructions() const { + return Cast(this)->fused_instructions(); } +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloInstruction::fused_instructions() { + return Cast(this)->fused_instructions(); +} + +int64 HloInstruction::fused_instruction_count() const { + return Cast(this)->fused_instruction_count(); +} + +HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const { + return Cast(this)->fused_parameter(parameter_number); +} + +const std::vector& HloInstruction::fused_parameters() const { + return Cast(this)->fused_parameters(); +} + +const bool HloInstruction::IsMultiOutputFusion() const { + const HloFusionInstruction* fusion = DynCast(this); + return fusion != nullptr && fusion->IsMultiOutputFusion(); +} + +HloInstruction::FusionKind HloInstruction::fusion_kind() const { + return Cast(this)->fusion_kind(); +} + +void HloInstruction::set_fusion_kind(FusionKind kind) { + return Cast(this)->set_fusion_kind(kind); +} + +RandomDistribution HloInstruction::random_distribution() const { + return Cast(this)->random_distribution(); +} + +int64 HloInstruction::parameter_number() const { + return Cast(this)->parameter_number(); +} + +int64 HloInstruction::tuple_index() const { + return Cast(this)->tuple_index(); +} + +int32 HloInstruction::exponent_bits() const { + return Cast(this)->exponent_bits(); +} + +int32 HloInstruction::mantissa_bits() const { + return Cast(this)->mantissa_bits(); +} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 234dbc8399de2d88209dd8dd2be58dd152ddbe76..a206cdab2739cd5046295179ead5d8bf19d521c4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -37,6 +37,8 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -50,6 +52,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -319,7 +322,7 @@ class HloInstruction { kCustom, }; - ~HloInstruction(); + virtual ~HloInstruction(); // Creates an instruction from the given proto. Arguments: // @@ -423,10 +426,27 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, const int exponent_bits, const int mantissa_bits); - // Creates a cross replica sum op. + // Creates a cross replica reduction op. + // + // `reduction_computation`: the reduction function. + // + // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all + // replicas belong to one group. Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // + // `channel_id`: for Allreduce nodes from different models, if they have the + // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be + // applied cross models. + // + // TODO(b/79737069): Rename this to AllReduce. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* reduce_computation, + tensorflow::gtl::ArraySlice replica_group_ids, + tensorflow::StringPiece barrier, + const tensorflow::gtl::optional& channel_id = + tensorflow::gtl::nullopt); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. @@ -597,6 +617,13 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice window_bounds); + // Creates a kDomain instruction which delimits an HLO domain which have + // the provided user and operand side metadata. + static std::unique_ptr CreateDomain( + const Shape& shape, HloInstruction* operand, + std::unique_ptr operand_side_metadata, + std::unique_ptr user_side_metadata); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -638,6 +665,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates a token instruction used for joining or creating token types which + // thread through side-effecting operations. + static std::unique_ptr CreateGenerateToken( + tensorflow::gtl::ArraySlice operands); + // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice output_window_dims, @@ -676,6 +708,10 @@ class HloInstruction { using InstructionVector = tensorflow::gtl::InlinedVector; const InstructionVector& operands() const { return operands_; } + // Returns the vector of unique operands, in the same order they are found + // within the operand vector. + InstructionVector unique_operands() const; + // Returns the index of 'target' in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64 operand_index(const HloInstruction* target) const; @@ -762,15 +798,16 @@ class HloInstruction { } } + if (backend_config_ != other.backend_config_) { + return false; + } + return IdenticalSlowPath(other, eq_computations); } // Returns whether the instruction has a constant operand. bool HasConstantOperand() const; - // Returns whether this instruction does a rank-2 transposition. - bool IsRank2Transpose() const; - // Replaces the use of this instruction in "user" with "new_producer". Note // that there might be multiple uses of this instruction in "user"; all will // be replaced. @@ -839,38 +876,6 @@ class HloInstruction { template Status Visit(DfsHloVisitorBase* visitor); - // Returns the literal associated with this instruction. - // - // Note: only constant and parameter opcodes have an associated literal. - const Literal& literal() const; - - // Returns whether there is literal associated with this instruction. - bool HasLiteral() const; - - // Returns the parameter number associated with this instruction. - // - // Note: only parameter opcodes have an associated parameter number. - int64 parameter_number() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_number_; - } - - // Returns the dimension sizes or numbers associated with this instruction. - // - // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, - // and reverse. - const std::vector& dimensions() const; - int64 dimensions(int64 index) const; - - // Accessor for the dimension in which a concatenate HLO should occur. - // Precondition: opcode() == HloOpcode::kConcatenate - int64 concatenate_dimension() const; - - // Returns the tuple index associated with this instruction. - // - // Precondition: opcode() == HloOpcode::kGetTupleElement - int64 tuple_index() const; - // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. @@ -965,7 +970,7 @@ class HloInstruction { string ToShortString() const; // Returns a serialized representation of this instruction. - HloInstructionProto ToProto() const; + virtual HloInstructionProto ToProto() const; // Returns a category for the HLO. This could be something like "convolution" // or "elementwise". @@ -977,111 +982,26 @@ class HloInstruction { HloInstruction* tracing() const; void set_tracing(HloInstruction* trace_instruction); - // Returns the channel id associated with the instruction. The id is - // shared between each Send/Recv pair and is globally unique to identify each - // channel. - // - // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv - int64 channel_id() const { return channel_id_; } - // Returns the channel name associated with the instruction. The name is // used to identify host Send/Recv operations. // // Precondition: opcode() == HloOpcode::kHostCompute string channel_name() const { return channel_name_; } - // Returns feature_index field associated with the instruction. The index - // represents the index of the feature dimension. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - int64 feature_index() const { return feature_index_; } - - // Returns a epsilon value associated with the instruction. The is a small - // number added to the variance to avoid divide-by-zero error. - // - // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, - // or kBatchNormGrad. - float epsilon() const { return epsilon_; } - // Returns the infeed configuration string. The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. string infeed_config() const { return infeed_config_; } void set_infeed_config(const string& config) { infeed_config_ = config; } - // Returns a tag to be used in tracing. - // - // Precondition: opcode() == HloOpcode::kTrace - string TracingTag() const; - - // Returns whether the instruction is a constant. - bool IsConstant() const; - // Returns true if this instruction is fused, ie contained within a fusion // instruction. bool IsFused() const; - // Returns the computation for this fused instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloComputation* fused_instructions_computation() const; - // Returns true if this instruction can be legally fused into a fusion // instruction. bool IsFusable() const; - // Returns the root instruction of the fused expression contained within this - // fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_expression_root() const; - - // Returns the list of fused instructions inside this fusion instruction. The - // returned type is a range of HloInstruction*s. - // - // Precondition: opcode() == HloOpcode::kFusion - const tensorflow::gtl::iterator_range>::const_iterator>> - fused_instructions() const; - - const tensorflow::gtl::iterator_range< - UnwrappingIterator>::iterator>> - fused_instructions(); - - // Gets the number of instructions inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - int64 fused_instruction_count() const; - - // Returns the fused parameter instruction in this fusion instruction - // corresponding to the given parameter number. - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* fused_parameter(int64 parameter_number) const; - - // Returns the vector of fused parameters inside this fusion instruction. - // - // Precondition: opcode() == HloOpcode::kFusion - const std::vector& fused_parameters() const; - - // Returns true if this instruction is a fusion instruction that generates - // multiple outputs. - const bool IsMultiOutputFusion() const { - return opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple; - } - - FusionKind fusion_kind() const { - CHECK_EQ(HloOpcode::kFusion, opcode_); - return fusion_kind_; - } - - void set_fusion_kind(FusionKind kind) { - CHECK_EQ(HloOpcode::kFusion, opcode_); - fusion_kind_ = kind; - } - // Returns the sharding applied to this operator. // REQUIRES: has_sharding() is true. const HloSharding& sharding() const { @@ -1094,16 +1014,23 @@ class HloInstruction { } // Returns the sharding unique device, if any. tensorflow::gtl::optional sharding_unique_device() const { - if (sharding_ == nullptr || !sharding_->HasUniqueDevice()) { + if (sharding_ == nullptr) { return tensorflow::gtl::optional(); } - return sharding_->UniqueDevice().ValueOrDie(); + auto device = sharding_->UniqueDevice(); + return device.ok() ? device.ValueOrDie() + : tensorflow::gtl::optional(); } // Sets the sharding of this operator. Should only be called by HloModule or // HloComputation methods. void set_sharding(const HloSharding& sharding) { sharding_ = MakeUnique(sharding); } + void set_single_sharding(const HloSharding& sharding); + // Sets a sharding that assigns the current instruction to device. + void set_device_sharding(int64 device) { + set_single_sharding(HloSharding::AssignDevice(device)); + } // Remove any sharding from this operator. void clear_sharding() { sharding_ = nullptr; } // Return true if this operator has a sharding assigned. @@ -1117,6 +1044,15 @@ class HloInstruction { return other->has_sharding() ? sharding() == other->sharding() : false; } + // Retrieves the operand side metadata of a kDomain instruction. + const DomainMetadata& operand_side_metadata() const { + return *operand_side_metadata_; + } + // Retrieves the user side metadata of a kDomain instruction. + const DomainMetadata& user_side_metadata() const { + return *user_side_metadata_; + } + // When creating a new instruction which either replaces, or shifts up (kCopy // insertion case), another instruction, we need to make sure the certain // properties of the new instruction are copied into the derived one. As of @@ -1124,93 +1060,6 @@ class HloInstruction { // instruction. void SetupDerivedInstruction(HloInstruction* derived_instruction) const; - // Adds a new operand the fusion instruction. - HloInstruction* AddFusionOperand(HloInstruction* new_operand); - - // Merges the fused instructions from 'instruction_to_merge' into the - // fused instruction set of 'this', updating operands as necessary. - // - // Precondition: opcode() == HloOpcode::kFusion - // Predondition: 'instruction_to_merge' must be an operand of 'this'. - void MergeFusionInstruction(HloInstruction* instruction_to_merge); - - // Merges the fused instructions from instruction_to_merge into the fused - // instruction set of 'this' and generates multioutput fusion instructions. - // All the users of instruction_to_merge will be redirected to 'this' - // instruction. instruction_to_merge will be removed from its parent - // computation. - // - // Precondition: opcode() == HloOpcode::kFusion - void MergeFusionInstructionIntoMultiOutput( - HloInstruction* instruction_to_merge); - - // Fuses the given instruction in this fusion instruction. instruction_to_fuse - // is cloned and the clone is placed in the fusion - // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather - // than moved to cleanly handle the case where the instruction has a use - // outside the fusion instruction. Moving such an instruction into a fusion - // instruction would violate the single-result invariant of HLO instructions - // and significantly complicate code generation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse); - } - - // Fuses the given instruction in this fusion instruction and generate - // multioutput fusion instruction. A clone of the instruction_to_fuse will - // be part of the output of fusion instructions. The users of - // instruction_to_fuse will be redirected to this fusion instructions. - // instruction_to_fuse will be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionIntoMultiOutput( - HloInstruction* instruction_to_fuse) { - return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); - } - - // Returns the start index in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_starts(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_starts_[dimension]; - } - const std::vector& slice_starts() const { return slice_starts_; } - - // Returns the (exclusive) limit index in the given dimension for a slice - // node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_limits(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_[dimension]; - } - const std::vector& slice_limits() const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_limits_; - } - - // Returns the stride in the given dimension for a slice node. - // - // Precondition: opcode() == HloOpcode::kSlice - int64 slice_strides(int64 dimension) const { - CHECK_EQ(HloOpcode::kSlice, opcode_); - return slice_strides_[dimension]; - } - const std::vector& slice_strides() const { return slice_strides_; } - - // Returns the flag that describes whether a slice must be lowered into an - // offset into the original operand. - bool IsInPlaceSlice() const { return is_in_place_slice_; } - - // Sets and returns the flag that describes whether a slice must be lowered - // into an offset into the original operand. - bool SetIsInPlaceSlice(bool value) { - is_in_place_slice_ = value; - return value; - } - // Returns the size of the slice in the given dimension for a dynamic // slice node. // @@ -1224,22 +1073,6 @@ class HloInstruction { return dynamic_slice_sizes_; } - // Returns the number of exponent bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 exponent_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return exponent_bits_; - } - - // Returns the number of mantissa bits for a reduce-precision node. - // - // Precondition: opcode() == HloOpcode::kReducePrecision - int32 mantissa_bits() const { - CHECK_EQ(HloOpcode::kReducePrecision, opcode_); - return mantissa_bits_; - } - // Returns data on the window in a windowed operation such as // convolution. const Window& window() const { @@ -1277,19 +1110,6 @@ class HloInstruction { MakeUnique(dnums); } - FftType fft_type() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_type_; - } - - const std::vector& fft_length() const { - CHECK_EQ(HloOpcode::kFft, opcode_); - return fft_length_; - } - - // Returns the dump string of the convolution dimension numbers. - string ConvolutionDimensionNumbersToString() const; - // Returns data on the dimension numbers used for a dot operation. const DotDimensionNumbers& dot_dimension_numbers() const { CHECK(dot_dimension_numbers_ != nullptr); @@ -1312,35 +1132,19 @@ class HloInstruction { // Returns the dump string of the gather dimension numbers. string GatherDimensionNumbersToString() const; - // Returns the random distribution for this rng node. - // - // Precondition: opcode() == HloOpcode::kRng - RandomDistribution random_distribution() const; - - // See documentation for Clone(). - using CloneMap = std::unordered_map; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of - // the instruction to form the name of the cloned instruction. Ignores the - // control predecessors and successors of this HLO instruction. - // - // If the module pointer is not nullptr, then any cloned computations will be - // added to this module in order to support deep cloning. Otherwise the module - // of the instruction is used. - // - // If clone_map is not nullptr, then each original instruction that is cloned - // will be inserted and map to its clone. clone_map should not already contain - // any of the instructions to clone. - std::unique_ptr Clone(const string& suffix = "clone", - HloModule* module = nullptr, - CloneMap* clone_map = nullptr) const; + // the instruction to form the name of the cloned instruction. + // Ignores the control predecessors and successors of this HLO instruction. + std::unique_ptr Clone( + const string& suffix = "clone", HloCloneContext* context = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. std::unique_ptr CloneWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr, CloneMap* clone_map = nullptr) const; + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -1410,9 +1214,14 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Gets/sets the string identifier for this instruction. + // Gets the string identifier for this instruction. const string& name() const { return name_; } - void set_name(tensorflow::StringPiece name) { name_ = std::string(name); } + + // Sets the string identifier for this instruction. Name will be sanitized to + // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + void SetAndSanitizeName(const string& name) { + name_ = NameUniquer::GetSanitizedName(name); + } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1435,13 +1244,34 @@ class HloInstruction { // this field and they cannot interpret it due to its meaning being backend // specific. // - // TODO(b/78194644): Introduce structured configuration format as per - // go/xla-heuristics. - const string& backend_config() const { return backend_config_; } - void set_backend_config(string backend_config) { - backend_config_ = std::move(backend_config); + // ConfigProto should be a protobuf Message type. + template + StatusOr backend_config() const { + ConfigProto proto; + TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto)); + return std::move(proto); + } + Status set_backend_config(const tensorflow::protobuf::Message& proto); + + // Getter/setter for raw JSON-encoded backend config. Prefer the + // functions above that deal in proto Messages where possible. + const string& raw_backend_config_string() const { return backend_config_; } + void set_raw_backend_config_string(string config_str) { + backend_config_ = std::move(config_str); } + // Returns a string representation of a proto in the format used by + // raw_backend_config_string. + // + // This is morally equivalent to: + // + // HloInstruction instr; + // TF_RETURN_IF_ERROR(instr.set_backend_config(proto)); + // return instr.raw_backend_config_string(); + // + static StatusOr BackendConfigToRawString( + const tensorflow::protobuf::Message& proto); + // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } const OpMetadata& metadata() const { return metadata_; } @@ -1472,13 +1302,193 @@ class HloInstruction { void set_outer_dimension_partitions( const std::vector& outer_dimension_partitions); - // Change the layout for an Constant Hlo instruction to match new_layout. For - // tuple shaped constants shape_index is the path to the internal array - // subshape whose layout needs to be changed. + // Old methods kept for smooth subclassing transition BEGIN. + // TODO(b/80131774): Remove this code. + + // Delegates to HloBatchNormInstruction::feature_index. + int64 feature_index() const; + + // Delegates to HloBatchNormInstruction::epsilon. + float epsilon() const; + + // Delegates to HloFftInstruction::fft_type. + FftType fft_type() const; + + // Delegates to HloFftInstruction::fft_length. + const std::vector& fft_length() const; + + // Delegates to HloSendRecvInstruction::channel_id. + int64 channel_id() const; + + // Returns the dimension sizes or numbers associated with this instruction. + virtual const std::vector& dimensions() const { + LOG(FATAL) << "Unimplemented method."; + } + virtual int64 dimensions(int64 index) const { + LOG(FATAL) << "Unimplemented method."; + } + + // Delegates to HloConcatenateInstruction::concatenate_dimension. + int64 concatenate_dimension() const; + + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + + // Delegates to HloSliceInstruction::slice_start. + int64 slice_starts(int64 dimension) const; + const std::vector& slice_starts() const; + + // Delegates to HloSliceInstruction::slice_limits. + int64 slice_limits(int64 dimension) const; + const std::vector& slice_limits() const; + + // Delegates to HloSliceInstruction::slice_strides. + int64 slice_strides(int64 dimension) const; + const std::vector& slice_strides() const; + + // Delegates to HloSliceInstruction::IsInPlaceSlice. + bool IsInPlaceSlice() const; + + // Returns the literal associated with this instruction. + const Literal& literal() const; + + // Returns whether the instruction is a constant. + bool IsConstant() const; + + // Delegate to HloConstantInstruction::RelayoutConstant. void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); + // Delegates to HloTraceInstruction::TracingTag. + string TracingTag() const; + + // Delegates to HloFusionInstruction::AddFusionOperand. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Delegates to HloFusionInstruction::MergeFusionInstruction. + void MergeFusionInstruction(HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput. + void MergeFusionInstructionIntoMultiOutput( + HloInstruction* instruction_to_merge); + + // Delegates to HloFusionInstruction::FuseInstruction. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse); + + // Delegates to HloFusionInstruction::fused_instruction. + HloComputation* fused_instructions_computation() const; + + // Delegates to HloFusionInstruction::fused_expression_root. + HloInstruction* fused_expression_root() const; + + // Delegates to HloFusionInstruction::fused_instructions. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Delegates to HloFusionInstruction::fused_instruction_count. + int64 fused_instruction_count() const; + + // Delegates to HloFusionInstruction::fused_parameter. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Delegates to HloFusionInstruction::fused_parameters. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const; + + // Delegates to HloFusionInstruction::fusion_kind. + FusionKind fusion_kind() const; + + // Delegates to HloFusionInstruction::set_fusion_kind. + void set_fusion_kind(FusionKind kind); + + // Delegates to HloRngInstruction::random_distribution. + RandomDistribution random_distribution() const; + + // Delegates to HloParameterInstruction::parameter_number. + int64 parameter_number() const; + + // Delegates to HloGetTupleElementInstruction::tuple_index. + int64 tuple_index() const; + + // Returns the number of exponent bits for a reduce-precision node. + int32 exponent_bits() const; + + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const; + // Old methods kept for smooth subclassing transition END. + + // Returns the group ids of each replica for CrossReplicaSum op. + const std::vector& replica_group_ids() const { + return replica_group_ids_; + } + + // Returns the barrier config used for the CrossReplicaSum implementation of + // each backend. + string cross_replica_sum_barrier() const { + return cross_replica_sum_barrier_; + } + void set_cross_replica_sum_barrier(string barrier) { + cross_replica_sum_barrier_ = barrier; + } + + protected: + enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; + // Helper class for computing OperandElementUse for kFusion. + class FusionReusesParamElements; + + // Internal constructor for a given opcode/shape, other fields must be filled + // by factory methods. + HloInstruction(HloOpcode opcode, const Shape& shape); + + // Appends operand to the list of operands and adds this instruction as a user + // of the operand. + void AppendOperand(HloInstruction* operand); + + void RemoveOperandAt(int index) { + operands_.erase(operands_.begin() + index); + } + + void AppendComputation(HloComputation* computation) { + called_computations_.push_back(computation); + } + + void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); } + private: + // Implementation for non-common logic of CloneWithNewOperands. + virtual std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + // TODO(b/80131774): This should be pure virtual. + LOG(FATAL) << "Unimplemented method."; + } + + // Implementation for non-common logic of ExtraAttributesToString. + virtual std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {}; + } + + // Implementation for IsElementwise if operand_idx is nullopt and for + // IsElementwiseOnOperand if otherwise. + // + // NOTE: For all instructions other than kFusion, being elementwise on one of + // the operands is equivalent to being elementwise on all the operands. + virtual bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const; // Prints an instruction to a string. // // The canonical string representation needs to name operands and instruction @@ -1489,7 +1499,7 @@ class HloInstruction { CanonicalNameMap* canonical_name_map) const; // Prints an operand to a string. - string OperandsToStringWithCanonicalNameMap( + virtual string OperandsToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const; @@ -1497,13 +1507,8 @@ class HloInstruction { // OperandsToStringWithCanonicalNameMap() functions. friend class HloComputation; - enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; - - // Helper class for computing OperandElementUse for kFusion. - class FusionReusesParamElements; - // See comments on Identical(). - bool IdenticalSlowPath( + virtual bool IdenticalSlowPath( const HloInstruction& other, const std::function& eq_computations) const; @@ -1513,55 +1518,19 @@ class HloInstruction { const Shape& shape, HloOpcode opcode, tensorflow::gtl::ArraySlice operands); - // Appends operand to the list of operands and adds this instruction as a user - // of the operand. - void AppendOperand(HloInstruction* operand); - // Adds a user for this instruction. void AddUser(HloInstruction* user); // Removes a user for this instruction. void RemoveUser(HloInstruction* user); - // Internal constructor for a given opcode/shape, other fields must be filled - // by factory methods. - HloInstruction(HloOpcode opcode, const Shape& shape); - - // Fuses the given instruction into this fusion instruction. When add_output - // is false (which is the default), instruction_to_fuse is cloned and the - // clone is placed in the fusion instruction. instruction_to_fuse is - // unchanged. - // - // When add_output is true, a clone of the instruction_to_fuse will be part - // of the output of fusion instructions. The users of instruction_to_fuse - // will be redirected to this fusion instructions. instruction_to_fuse will - // be removed from its parent computation. - // - // Precondition: this->opcode() == HloOpcode::kFusion - HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones the given instruction_to_fuse and insert the clone into this fusion - // instruction. If add_output is true, a clone of instruction_to_fuse will - // be in the output of the this fusion instruction (part of the tuple of the - // fusion root). - // - // Precondition: opcode() == HloOpcode::kFusion - HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, - bool add_output = false); - - // Clones a fusion instruction with a new shape and operands. - std::unique_ptr CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands, - HloModule* module = nullptr) const; - - // Returns true if this instruction can legally have the dimensions field - // set. Used for checking precondition of dimensions field accessors. - bool CanHaveDimensionsField() const; - // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; + // Helper for implementing backend_config(). Parses backend_config_ into the + // given proto. + Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const; + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. @@ -1592,16 +1561,6 @@ class HloInstruction { // Result shape of this instruction. Shape shape_; - // Literal, only present for kConstant. - std::unique_ptr literal_; - - // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = -1; - - // Dimensions present for some operations that require reshaping or - // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. - std::vector dimensions_; - // Describes the window in a windowed operation such as convolution. std::unique_ptr window_; @@ -1614,24 +1573,6 @@ class HloInstruction { std::unique_ptr gather_dimension_numbers_; std::vector gather_window_bounds_; - // Describes FFT type for an FFT instruction. - FftType fft_type_ = FftType::FFT; - - // Indicates the FFT length for an FFT instruction. - std::vector fft_length_; - - // Describes the [begin, end) index range for a slice. - std::vector slice_starts_; - std::vector slice_limits_; - std::vector slice_strides_; - - // Describes whether the slice can be lowered to an offset into the operand. - bool is_in_place_slice_ = false; - - // The bit sizes for a reduce-precision operation. - int32 exponent_bits_ = 0; - int32 mantissa_bits_ = 0; - // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). std::vector dynamic_slice_sizes_; @@ -1640,14 +1581,12 @@ class HloInstruction { // padding of this pad instruction. Only set for pad instructions. std::unique_ptr padding_config_; - // The type of the fusion. Used by kFusion only. - FusionKind fusion_kind_; - // The sharding, if one exists. std::unique_ptr sharding_; - // For parameter instructions this field holds the parameter number. - int64 parameter_number_ = 0; + // Fields used by the kDomain instruction. + std::unique_ptr operand_side_metadata_; + std::unique_ptr user_side_metadata_; // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; @@ -1686,22 +1625,6 @@ class HloInstruction { // an operand. HloInstruction* trace_instruction_ = nullptr; - // The distribution requested for random number generation. - // Only present for kRng. - RandomDistribution distribution_; - - // A small float number added to the variance to avoid divide-by-zero error. - // Only present for kBatchNormTraining. - float epsilon_ = 0.0f; - - // An integer value representing the index of the feature dimension. - // Only present for kBatchNormTraining. - int64 feature_index_ = -1; - - // Represents a unique identifier for each Send/Recv instruction pair. - // Only present for kSend or kRecv. - int64 channel_id_ = -1; - // The string representation of the infeed configuration. string infeed_config_; @@ -1709,6 +1632,12 @@ class HloInstruction { // HLO. See the documentation on backend_config(). string backend_config_; + // The group id of each replica for CrossReplicaSum. + std::vector replica_group_ids_; + + // The string representation of the barrier config used for CrossReplicaSum. + string cross_replica_sum_barrier_; + // String identifier for instruction. string name_; @@ -1731,6 +1660,9 @@ StatusOr StringToFusionKind( string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); string RandomDistributionToString(const RandomDistribution& distribution); +string ConvolutionDimensionNumbersToString( + const ConvolutionDimensionNumbers& dnums); + StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index a61c472c72804b077d21274d2e866a69c5e73157..5d6f8b931f0c665fba03e1c845214fa83aabf12e 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,11 +24,13 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/window_util.h" namespace xla { namespace { @@ -340,7 +342,7 @@ TEST_F(HloInstructionTest, TrivialMap) { // Builds a parameter and feeds it to the map. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto map = builder.AddInstruction( HloInstruction::CreateMap(f32a100x10, {param0}, add_f32)); module->AddEntryComputation(builder.Build()); @@ -379,7 +381,7 @@ TEST_F(HloInstructionTest, TrivialReduce) { // Builds a parameter and an initial value and feeds them to the reduce. HloComputation::Builder builder(TestName()); auto param0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, f32a100x10, "")); + HloInstruction::CreateParameter(0, f32a100x10, "p")); auto const0 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( @@ -978,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) { } } +TEST_F(HloInstructionTest, MapIsElementwise) { + auto module = CreateNewModule(); + const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0}); + HloComputation::Builder builder(TestName()); + HloComputation::Builder map_builder("id"); + map_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + auto x = + builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x")); + auto map = builder.AddInstruction( + HloInstruction::CreateMap(r2f32, {x}, map_computation)); + module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(map->IsElementwise()); +} + TEST_F(HloInstructionTest, PartiallyElementwise) { const Shape r1f32 = ShapeUtil::MakeShape(F32, {5}); const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5}); @@ -1494,5 +1513,117 @@ TEST_F(HloInstructionTest, CanonnicalStringificationConditional) { })"); } +TEST_F(HloInstructionTest, CheckDeepClone) { + const char* const hlo_string = R"( +HloModule Module + +addy (lhs: s32[], rhs: s32[]) -> s32[] { + lhs = s32[] parameter(0) + rhs = s32[] parameter(1) + ROOT zadd = s32[] add(lhs, rhs) +} + +calla (x: s32[]) -> s32[] { + x = s32[] parameter(0) + reduce = s32[] reduce-window(x, x), to_apply=addy + ROOT xadd = s32[] add(x, reduce) +} + +body (bparam: s32[]) -> s32[] { + constant = s32[] constant(1) + bparam = s32[] parameter(0) + v = s32[] call(bparam), to_apply=calla + ROOT add = s32[] add(constant, bparam) +} + +condition (cparam: s32[]) -> pred[] { + xconstant = s32[] constant(5) + cparam = s32[] parameter(0) + ROOT greater-than = pred[] greater-than(xconstant, cparam) +} + +ENTRY entry (param: s32[]) -> s32[] { + eparam = s32[] parameter(0) + ROOT while = s32[] while(eparam), condition=condition, body=body + } +)"; + // Check that deep clones really deep clones every instruction and + // computations, without leaving dangling pointers to the old module. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + std::unique_ptr clone = module->Clone(); + for (HloComputation* computation : clone->computations()) { + EXPECT_EQ(computation->parent(), clone.get()); + for (HloInstruction* instruction : computation->instructions()) { + EXPECT_EQ(instruction->parent()->parent(), clone.get()); + } + } +} + +TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) { + const Shape shape = ShapeUtil::MakeShape(F32, {42}); + HloComputation::Builder builder("test"); + HloInstruction* p = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p")); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p)); + + EXPECT_TRUE(add1->Identical(*add2)); + add1->set_raw_backend_config_string("abc"); + EXPECT_FALSE(add1->Identical(*add2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + Window w = window_util::MakeWindow({1, 2, 3}); + instr1->set_window(w); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) { + auto instr1 = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + auto instr2 = instr1->Clone(); + EXPECT_TRUE(instr1->Identical(*instr2)); + + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr1->set_convolution_dimension_numbers(dnums); + EXPECT_FALSE(instr1->Identical(*instr2)); +} + +TEST_F(HloInstructionTest, CloneWindowOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + Window w = window_util::MakeWindow({1, 2, 3}); + instr->set_window(w); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals(clone->window(), w)) + << clone->window().DebugString(); +} + +TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) { + auto instr = HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}), + /*operands=*/{}, + /*custom_call_target=*/"foo"); + ConvolutionDimensionNumbers dnums; + dnums.set_output_batch_dimension(42); + instr->set_convolution_dimension_numbers(dnums); + auto clone = instr->Clone(); + EXPECT_TRUE(protobuf_util::ProtobufEquals( + clone->convolution_dimension_numbers(), dnums)) + << clone->convolution_dimension_numbers().DebugString(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc new file mode 100644 index 0000000000000000000000000000000000000000..d326d5d009fa3a378cfe92e1b081445f09d982e1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -0,0 +1,1287 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_instructions.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace xla { +namespace { + +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, + const HloInstruction* operand) { + std::vector operand_indices = instruction->OperandIndices(operand); + return std::all_of( + operand_indices.begin(), operand_indices.end(), + [instruction](int64 operand_index) { + return instruction->IsElementwiseOnOperand(operand_index); + }); +} +} // namespace + +HloBatchNormInstruction::HloBatchNormInstruction( + HloOpcode opcode, const Shape& shape, HloInstruction* operand, + HloInstruction* scale, float epsilon, int64 feature_index) + : HloInstruction(opcode, shape), + epsilon_(epsilon), + feature_index_(feature_index) { + AppendOperand(operand); + AppendOperand(scale); +} + +bool HloBatchNormInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return feature_index() == casted_other.feature_index() && + epsilon() == casted_other.epsilon(); +} + +HloInstructionProto HloBatchNormInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + return proto; +} + +std::vector HloBatchNormInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("epsilon=", epsilon()), + StrCat("feature_index=", feature_index())}; +} + +HloBatchNormTrainingInstruction::HloBatchNormTrainingInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormTraining, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); +} + +std::unique_ptr +HloBatchNormTrainingInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], epsilon(), + feature_index()); +} + +HloBatchNormInferenceInstruction::HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormInference, shape, operand, + scale, epsilon, feature_index) { + AppendOperand(offset); + AppendOperand(mean); + AppendOperand(variance); +} + +std::unique_ptr +HloBatchNormInferenceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloBatchNormGradInstruction::HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, HloInstruction* grad_output, + float epsilon, int64 feature_index) + : HloBatchNormInstruction(HloOpcode::kBatchNormGrad, shape, operand, scale, + epsilon, feature_index) { + AppendOperand(mean); + AppendOperand(variance); + AppendOperand(grad_output); +} + +std::unique_ptr +HloBatchNormGradInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 5); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); +} + +HloFftInstruction::HloFftInstruction( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) + : HloInstruction(HloOpcode::kFft, shape), fft_type_(fft_type) { + fft_length_.assign(fft_length.begin(), fft_length.end()); + AppendOperand(operand); +} + +HloInstructionProto HloFftInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } + return proto; +} + +std::vector HloFftInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("fft_type=", FftType_Name(fft_type())), + StrCat("fft_length={", Join(fft_length(), ","), "}")}; +} + +bool HloFftInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return fft_type() == casted_other.fft_type() && + fft_length() == casted_other.fft_length(); +} + +std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], fft_type_, + fft_length_); +} + +HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, + const Shape& shape, + int64 channel_id) + : HloInstruction(opcode, shape), channel_id_(channel_id) {} + +HloInstructionProto HloSendRecvInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_channel_id(channel_id_); + return proto; +} + +std::vector HloSendRecvInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("channel_id=", channel_id_)}; +} + +bool HloSendRecvInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + // Not yet supported. + return false; +} + +// Send instruction produces a tuple of {aliased operand, U32 context}. +HloSendInstruction::HloSendInstruction(HloInstruction* operand, + int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kSend, + ShapeUtil::MakeTupleShape( + {CHECK_NOTNULL(operand)->shape(), ShapeUtil::MakeShape(U32, {})}), + channel_id) { + AppendOperand(operand); +} + +std::unique_ptr HloSendInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(new_operands[0], channel_id()); +} + +HloSendDoneInstruction::HloSendDoneInstruction(HloSendInstruction* operand) + : HloSendRecvInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil(), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloSendDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +// Recv instruction produces a tuple of {receive buffer, U32 context}. +HloRecvInstruction::HloRecvInstruction(const Shape& shape, int64 channel_id) + : HloSendRecvInstruction( + HloOpcode::kRecv, + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}), + channel_id) {} + +std::unique_ptr HloRecvInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 0); + return MakeUnique( + ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); +} + +HloRecvDoneInstruction::HloRecvDoneInstruction(HloRecvInstruction* operand) + : HloSendRecvInstruction( + HloOpcode::kRecvDone, + ShapeUtil::GetTupleElementShape(operand->shape(), 0), + CHECK_NOTNULL(operand)->channel_id()) { + AppendOperand(operand); +} + +std::unique_ptr +HloRecvDoneInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + Cast(new_operands[0])); +} + +HloReverseInstruction::HloReverseInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kReverse, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloReverseInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReverseInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReverseInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr HloReverseInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloConcatenateInstruction::HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension) + : HloInstruction(HloOpcode::kConcatenate, shape), dimensions_({dimension}) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloInstructionProto HloConcatenateInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloConcatenateInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloConcatenateInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloConcatenateInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, + dimensions(0)); +} + +HloReduceInstruction::HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation) + : HloInstruction(HloOpcode::kReduce, shape), + dimensions_(dimensions_to_reduce.begin(), dimensions_to_reduce.end()) { + AppendOperand(arg); + AppendOperand(init_value); + AppendComputation(reduce_computation); +} + +HloInstructionProto HloReduceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloReduceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloReduceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + // Reduction results are determined by the reduction dimension and the + // reduction computation. + return dimensions() == casted_other.dimensions() && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloReduceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique( + shape, new_operands[0], new_operands[1], dimensions(), to_apply()); +} + +HloTransposeInstruction::HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions) + : HloInstruction(HloOpcode::kTranspose, shape), + dimensions_(dimensions.begin(), dimensions.end()) { + CHECK_EQ(shape.dimensions().size(), dimensions.size()); + CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); + CHECK(std::equal(operand->shape().dimensions().begin(), + operand->shape().dimensions().end(), + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; + AppendOperand(operand); +} + +bool HloTransposeInstruction::IsRank2Transpose() const { + return dimensions() == std::vector({1, 0}) && + shape().dimensions_size() == 2 && + std::equal(shape().dimensions().begin(), shape().dimensions().end(), + operand(0)->shape().dimensions().rbegin()); +} + +HloInstructionProto HloTransposeInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloTransposeInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloTransposeInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloTransposeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloBroadcastInstruction::HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension) + : HloInstruction(HloOpcode::kBroadcast, shape), + dimensions_(broadcast_dimension.begin(), broadcast_dimension.end()) { + AppendOperand(operand); +} + +HloInstructionProto HloBroadcastInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +std::vector HloBroadcastInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloBroadcastInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return dimensions() == casted_other.dimensions(); +} + +std::unique_ptr +HloBroadcastInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + dimensions()); +} + +HloMapInstruction::HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands) + : HloInstruction(HloOpcode::kMap, shape) { + CHECK(static_operands.empty()) << "static_operands not yet supported"; + for (auto operand : operands) { + AppendOperand(operand); + } + AppendComputation(map_computation); + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + dimensions_.resize(ShapeUtil::Rank(shape)); + std::iota(dimensions_.begin(), dimensions_.end(), 0); +} + +HloInstructionProto HloMapInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + return proto; +} + +bool HloMapInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (!dimensions().empty()) { + // Check that the map is executed in elementwise compatible dimensions. + if (dimensions().size() != shape().dimensions_size()) { + return false; + } + for (int i = 0; i < dimensions().size(); ++i) { + if (dimensions()[i] != i) { + return false; + } + } + } + return true; +} + +std::vector HloMapInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("dimensions={", Join(dimensions(), ","), "}")}; +} + +bool HloMapInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return eq_computations(to_apply(), other.to_apply()); +} + +std::unique_ptr HloMapInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, new_operands, to_apply()); +} + +HloSliceInstruction::HloSliceInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) + : HloInstruction(HloOpcode::kSlice, shape), + slice_starts_(start_indices.begin(), start_indices.end()), + slice_limits_(limit_indices.begin(), limit_indices.end()), + slice_strides_(strides.begin(), strides.end()) { + AppendOperand(operand); + // For backward compatibility with old serialized computations: if there are + // no strides, assume all strides are 1. + // TODO(b/63317920): remove this code. + if (slice_strides_.empty()) { + slice_strides_ = std::vector(start_indices.size(), 1LL); + } +} + +HloInstructionProto HloSliceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + return proto; +} + +std::vector HloSliceInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector bounds; + bounds.reserve(slice_starts_.size()); + const bool omit_stride = + std::all_of(slice_strides_.begin(), slice_strides_.end(), + [](int64 stride) { return stride == 1; }); + for (int i = 0; i < slice_starts_.size(); ++i) { + string stride_str = omit_stride ? "" : StrCat(":", slice_strides_[i]); + bounds.push_back( + StrCat("[", slice_starts_[i], ":", slice_limits_[i], stride_str, "]")); + } + return {StrCat("slice={", Join(bounds, ", "), "}")}; +} + +bool HloSliceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return slice_starts_ == other_slice.slice_starts_ && + slice_limits_ == other_slice.slice_limits_ && + slice_strides_ == other_slice.slice_strides_; +} + +std::unique_ptr HloSliceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], slice_starts_, + slice_limits_, slice_strides_); +} + +HloConstantInstruction::HloConstantInstruction(std::unique_ptr literal) + : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()), + literal_(std::move(literal)) {} + +HloInstructionProto HloConstantInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloConstantInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +void HloConstantInstruction::RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index) { + Shape* mutable_array_subshape = + ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index); + CHECK(ShapeUtil::IsArray(*mutable_array_subshape)); + + // Normally array_subshape will always have a layout, but this invariant is + // temporarily broken in LayoutAssignment::AssignLayouts. + + if (!mutable_array_subshape->has_layout() || + !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) { + literal_ = literal_->Relayout(new_layout, shape_index); + *mutable_array_subshape->mutable_layout() = new_layout; + } +} + +bool HloConstantInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& other_slice = static_cast(other); + return literal() == other_slice.literal(); +} + +std::unique_ptr +HloConstantInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(literal_->CloneToUnique()); +} + +string HloConstantInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + string operands; + // For constants, show the actual value in place of an empty operand list. + if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || + options.print_large_constants()) { + // Literal::ToString emits multidimensional arrays over multiple + // lines. Compact this into one line by stripping out white space. + string tmp = literal().ToString(); + std::replace(tmp.begin(), tmp.end(), '\n', ' '); + std::vector v = tensorflow::str_util::Split(tmp, ' '); + bool first = true; + // Concatenate elements in "v" with spaces separating them, but ignoring + // empty entries. + for (const auto& s : v) { + if (s.empty()) { + continue; + } + StrAppend(&operands, (first ? "" : " "), s); + first = false; + } + } else { + // Do not show large constants or tuples. + operands = "{...}"; + } + return operands; +} + +HloTraceInstruction::HloTraceInstruction(const string& tag, + HloInstruction* operand) + : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()), + literal_(Literal::CreateR1U8(tag)) { + AppendOperand(operand); + operand->set_tracing(this); +} + +HloInstructionProto HloTraceInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_literal() = literal_->ToProto(); + return proto; +} + +bool HloTraceInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloTraceInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode()); +} + +HloFusionInstruction::HloFusionInstruction(const Shape& shape, + FusionKind fusion_kind, + HloInstruction* fused_root) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + CHECK(fused_root != nullptr); + SetAndSanitizeName("fusion"); + set_parent(fused_root->parent()); + set_metadata(fused_root->metadata()); + CloneAndFuseInternal(fused_root); +} + +HloFusionInstruction::HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) + : HloInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { + for (auto operand : operands) { + AppendOperand(operand); + } + SetAndSanitizeName("fusion"); + AppendComputation(fusion_computation); + fusion_computation->SetFusionInstruction(this); +} + +HloInstructionProto HloFusionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_fusion_kind(xla::ToString(fusion_kind())); + proto.add_called_computation_ids( + fused_instructions_computation()->unique_id()); + return proto; +} + +bool HloFusionInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + if (fusion_kind() != FusionKind::kLoop) { + return false; + } + + if (!operand_idx.has_value()) { + for (auto* fused : fused_instructions()) { + if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { + return false; + } + } + return true; + } + // A loop-fusion is elementwise on an operand if all operations (computed + // using BFS) between the operand and the fused root are elementwise. + std::deque worklist; + std::unordered_set visited; + worklist.push_back(fused_parameter(operand_idx.value())); + visited.insert(fused_parameter(operand_idx.value())); + while (!worklist.empty()) { + HloInstruction* operand = worklist.front(); + worklist.pop_front(); + for (HloInstruction* user : operand->users()) { + CHECK_GE(user->unique_id(), 0); + if (ContainsKey(visited, user)) { + continue; + } + if (user->IsElementwise() || + IsInstructionElementwiseOnOperand(user, operand)) { + worklist.push_back(user); + visited.insert(user); + } else { + return false; + } + } + } + return true; +} + +HloInstruction* HloFusionInstruction::AddFusionOperand( + HloInstruction* new_operand) { + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + +void HloFusionInstruction::MergeFusionInstruction( + HloFusionInstruction* instruction_to_merge) { + CHECK(std::find(operands().begin(), operands().end(), instruction_to_merge) != + operands().end()); + // Clone the instruction from which to merge fused instructions. + std::unique_ptr cloned = instruction_to_merge->Clone(); + HloFusionInstruction* cloned_fusion = + static_cast(cloned.get()); + // Replace uses of fused parameters with the corresponding operand of the + // fusion. Add all non-parameter fused instructions to + // 'unfused_instructions' to be merged into 'this'. This is done in reverse + // post order. + std::vector unfused_instructions; + auto fused_instructions = cloned_fusion->fused_instructions_computation() + ->MakeInstructionPostOrder(); + for (auto fused_it = fused_instructions.rbegin(); + fused_it != fused_instructions.rend(); ++fused_it) { + auto fused_instruction = *fused_it; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + TF_CHECK_OK( + fused_instruction->ReplaceAllUsesWith(cloned_fusion->mutable_operand( + fused_instruction->parameter_number()))); + } else { + unfused_instructions.push_back(fused_instruction); + } + } + CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root()); + // Replace instruction_to_merge use of 'this' with unfused_root. + TF_CHECK_OK( + instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); + // Fuse 'unfused_instructions' into 'this'. + for (auto& instruction : unfused_instructions) { + FuseInstruction(instruction); + instruction->DetachFromOperands(); + } + CHECK_EQ(0, cloned_fusion->user_count()); + cloned_fusion->DetachFromOperands(); + TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( + cloned_fusion->fused_instructions_computation())); +} + +void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge) { + // Add all non-parameter fused instructions to 'unfused_instructions' to be + // merged into 'this'. `old_to_new' maps the instructions in the fused node + // to the disaseembled fusion instructions. + // Note that we add the unfused instructions to this->parent_ computation. + // This is necessary because the unique_id needs for an instruction and + // it's only added when inserting to the computation. + tensorflow::gtl::FlatMap old_to_new; + std::vector unfused_instructions; + auto computation_to_merge = + instruction_to_merge->fused_instructions_computation(); + auto post_order = computation_to_merge->MakeInstructionPostOrder(); + for (auto rit = post_order.rbegin(); rit != post_order.rend(); ++rit) { + auto fused_instruction = *rit; + if (fused_instruction->opcode() == HloOpcode::kParameter) { + InsertOrDie(&old_to_new, fused_instruction, + instruction_to_merge->mutable_operand( + fused_instruction->parameter_number())); + continue; + } + + // Here we clone the insertion and call FuseInstructionIntoMultiOutput() + // which clones again. This can be improved. + auto cloned_instruction = + parent()->AddInstruction(fused_instruction->Clone()); + unfused_instructions.push_back(cloned_instruction); + InsertOrDie(&old_to_new, fused_instruction, cloned_instruction); + } + for (auto unfused_instruction : unfused_instructions) { + for (int64 index = 0; index < unfused_instruction->operand_count(); + index++) { + auto new_operand = + FindOrDie(old_to_new, unfused_instruction->mutable_operand(index)); + TF_CHECK_OK(unfused_instruction->ReplaceOperandWith(index, new_operand)); + } + } + + HloInstruction* unfused_root = unfused_instructions.front(); + TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); + + TF_CHECK_OK( + instruction_to_merge->parent()->RemoveInstruction(instruction_to_merge)); + if (GetModule()) { + TF_CHECK_OK(GetModule()->RemoveEmbeddedComputation(computation_to_merge)); + } + + // Fuse the root instruction and generate multiple outputs. + FuseInstructionIntoMultiOutput(unfused_root); + TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); + // The rest instructions are of normal fusing. + for (int64 i = 1; i < unfused_instructions.size(); i++) { + auto instruction = unfused_instructions[i]; + FuseInstruction(instruction); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); + } +} + +HloComputation* HloFusionInstruction::fused_instructions_computation() const { + CHECK(!called_computations().empty()); + auto* fused_instructions_computation = called_computations().front(); + CHECK(fused_instructions_computation->IsFusionComputation()) + << "Computation " << fused_instructions_computation->name() + << " is not a fusion kind"; + return fused_instructions_computation; +} + +HloInstruction* HloFusionInstruction::fused_expression_root() const { + return fused_instructions_computation()->root_instruction(); +} + +HloInstruction* HloFusionInstruction::fused_parameter( + int64 parameter_number) const { + return fused_instructions_computation()->parameter_instruction( + parameter_number); +} + +const std::vector& HloFusionInstruction::fused_parameters() + const { + return fused_instructions_computation()->parameter_instructions(); +} + +const tensorflow::gtl::iterator_range>::const_iterator>> +HloFusionInstruction::fused_instructions() const { + const HloComputation* subcomp = fused_instructions_computation(); + return subcomp->instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloFusionInstruction::fused_instructions() { + return fused_instructions_computation()->instructions(); +} + +int64 HloFusionInstruction::fused_instruction_count() const { + return fused_instructions_computation()->instruction_count(); +} + +HloInstruction* HloFusionInstruction::FuseInstructionInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + // When add_output is false, this fusion instruction must be a user of + // instruction_to_fuse. + if (!add_output) { + CHECK(IsUserOf(instruction_to_fuse)); + } + HloInstruction* fused_instruction = + CloneAndFuseInternal(instruction_to_fuse, add_output); + return fused_instruction; +} + +HloInstruction* HloFusionInstruction::CloneAndFuseInternal( + HloInstruction* instruction_to_fuse, bool add_output) { + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); + VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); + HloInstruction* clone = nullptr; + if (called_computations().empty()) { + // New fusion instruction. It should not be a multioutput instruction. + CHECK(!add_output); + auto builder = HloComputation::Builder("fused_computation", this); + builder.AddInstruction(instruction_to_fuse->Clone(/*suffix=*/"")); + AppendComputation( + CHECK_NOTNULL(GetModule())->AddEmbeddedComputation(builder.Build())); + clone = fused_expression_root(); + } else { + clone = fused_instructions_computation()->AddInstruction( + instruction_to_fuse->Clone(/*suffix=*/"")); + // When add_output is false, instruction_to_fuse is necessarily an operand + // of the fusion instruction. After fusion this will no longer be the + // case. Remove the operand from the operand list and remove its + // corresponding fused parameter instruction. Renumber parameters as + // necessary to make parameter numbers consistent with their index in the + // fused_parameter_ vector. + bool in_operand_list = std::find(operands().begin(), operands().end(), + instruction_to_fuse) != operands().end(); + CHECK(add_output || in_operand_list); + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) { + if (instruction_to_fuse == operand(operand_num)) { + // replace the fused parameter instruction's uses with the clone. + HloInstruction* fused_parameter = fused_parameters[operand_num]; + TF_CHECK_OK(fused_parameter->ReplaceAllUsesWith(clone)); + + // Remove the corresponding fused parameter and operand from their + // respective vectors. + TF_CHECK_OK( + fused_instructions_computation()->RemoveParameter(operand_num)); + RemoveOperandAt(operand_num); + break; + } + } + // We've cloned instruction_to_fuse into this fusion instruction, so this + // fusion instruction is no longer a use of instruction_to_fuse. + if (in_operand_list) { + DetachFrom(instruction_to_fuse); + // When the instruction_to_fuse does not have other users, we don't need + // to generate a multioutput fusion instruction. + if (instruction_to_fuse->user_count() == 0) { + add_output = false; + } + } + } + + // Reread the parameters in the computation. + const std::vector& fused_parameters = + fused_instructions_computation()->parameter_instructions(); + + // Add each operand of the clone as an operand of the fusion instruction. A + // complication is that some clone operands may already be operands of the + // fusion instruction. + for (int64 operand_num = 0; operand_num < clone->operand_count(); + ++operand_num) { + HloInstruction* operand = clone->mutable_operand(operand_num); + + // See if this operand is already an operand of the fusion node. + CHECK_EQ(operands().size(), fused_parameters.size()); + HloInstruction* fused_param = nullptr; + for (int64 i = 0; i < operands().size(); ++i) { + if (this->operand(i) == operand) { + fused_param = fused_parameters[i]; + break; + } + } + + if (fused_param == nullptr) { + // Clone's operand was not already an operand of the fusion + // instruction. Add it as an operand and add a corresponding fused + // parameter instruction. + fused_param = AddFusionOperand(operand); + } + TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); + } + + if (add_output) { + CHECK_GT(instruction_to_fuse->user_count(), 0); + // If this is already a multioutput fusion instruction, expand the root + // tuple by 1. + HloInstruction* fused_root = fused_expression_root(); + HloInstruction::InstructionVector tuple_elements; + bool newly_created_tuple_instr = false; + if (fused_root->opcode() == HloOpcode::kTuple) { + tuple_elements = fused_root->operands(); + } else { + tuple_elements.push_back(fused_root); + newly_created_tuple_instr = true; + } + if (clone->opcode() == HloOpcode::kTuple) { + for (auto inst : clone->operands()) { + tuple_elements.push_back(inst); + } + } else { + tuple_elements.push_back(clone); + } + HloInstruction* new_root = fused_instructions_computation()->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); + fused_instructions_computation()->set_root_instruction(new_root); + *mutable_shape() = new_root->shape(); + if (fused_root->opcode() == HloOpcode::kTuple) { + TF_CHECK_OK( + fused_instructions_computation()->RemoveInstruction(fused_root)); + } + + // If this is a newly created multioutput instruction, we need to update + // the use of the original fusion instruction. + if (newly_created_tuple_instr) { + HloInstruction* new_instr = parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(fused_root->shape(), this, 0)); + TF_CHECK_OK(ReplaceAllUsesWith(new_instr)); + } + int64 index = tuple_elements.size(); + if (instruction_to_fuse->opcode() == HloOpcode::kTuple) { + index -= instruction_to_fuse->operand_count(); + std::vector to_be_removed; + for (auto old_gte : instruction_to_fuse->users()) { + CHECK_EQ(old_gte->opcode(), HloOpcode::kGetTupleElement); + int64 old_tuple_index = old_gte->tuple_index(); + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + old_gte->shape(), this, index + old_tuple_index)); + TF_CHECK_OK(old_gte->ReplaceAllUsesWith(new_gte)); + to_be_removed.push_back(old_gte); + } + for (auto old_gte : to_be_removed) { + TF_CHECK_OK(parent()->RemoveInstruction(old_gte)); + } + TF_CHECK_OK(fused_instructions_computation()->RemoveInstruction(clone)); + } else { + HloInstruction* new_gte = + parent()->AddInstruction(HloInstruction::CreateGetTupleElement( + clone->shape(), this, index - 1)); + TF_CHECK_OK(instruction_to_fuse->ReplaceAllUsesWith(new_gte)); + } + } + + VLOG(2) << "New clone:\n" << clone->ToString(); + return clone; +} + +std::vector HloFusionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("kind=", xla::ToString(fusion_kind()))}; +} + +bool HloFusionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return fusion_kind() == other.fusion_kind() && + eq_computations(fused_instructions_computation(), + other.fused_instructions_computation()); +} + +std::unique_ptr HloFusionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + HloModule* module = context != nullptr ? context->module() : GetModule(); + HloComputation* new_fused_computation = nullptr; + if (context != nullptr) { + new_fused_computation = + context->FindComputation(fused_instructions_computation()); + } + if (new_fused_computation == nullptr) { + new_fused_computation = module->AddEmbeddedComputation( + fused_instructions_computation()->Clone("clone", context)); + } + return MakeUnique(shape, fusion_kind(), new_operands, + new_fused_computation); +} + +HloRngInstruction::HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters) + : HloInstruction(HloOpcode::kRng, shape), distribution_(distribution) { + for (HloInstruction* param : parameters) { + AppendOperand(param); + } +} + +HloInstructionProto HloRngInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_distribution(distribution_); + return proto; +} + +std::vector HloRngInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("distribution=", RandomDistributionToString(distribution_))}; +} + +bool HloRngInstruction::IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const { + return true; +} + +bool HloRngInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + return false; +} + +std::unique_ptr HloRngInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(shape, distribution_, new_operands); +} + +HloParameterInstruction::HloParameterInstruction(int64 parameter_number, + const Shape& shape, + const string& name) + : HloInstruction(HloOpcode::kParameter, shape), + parameter_number_(parameter_number) { + SetAndSanitizeName(name); +} + +HloInstructionProto HloParameterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_parameter_number(parameter_number_); + return proto; +} + +string HloParameterInstruction::OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const { + return StrCat(parameter_number_); +} + +bool HloParameterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return parameter_number() == casted_other.parameter_number(); +} + +std::unique_ptr +HloParameterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + return MakeUnique(parameter_number_, shape, name()); +} + +HloGetTupleElementInstruction::HloGetTupleElementInstruction( + const Shape& shape, HloInstruction* operand, int64 index) + : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) { + CHECK(ShapeUtil::IsTuple(operand->shape())); + AppendOperand(operand); +} + +HloInstructionProto HloGetTupleElementInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_tuple_index(tuple_index_); + return proto; +} + +std::vector HloGetTupleElementInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("index=", tuple_index())}; +} + +bool HloGetTupleElementInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + return tuple_index() == casted_other.tuple_index(); +} + +std::unique_ptr +HloGetTupleElementInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique(shape, new_operands[0], + tuple_index()); +} + +HloReducePrecisionInstruction::HloReducePrecisionInstruction( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits) + : HloInstruction(HloOpcode::kReducePrecision, shape), + exponent_bits_(exponent_bits), + mantissa_bits_(mantissa_bits) { + AppendOperand(operand); +} + +HloInstructionProto HloReducePrecisionInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + return proto; +} + +std::vector HloReducePrecisionInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {StrCat("exponent_bits=", exponent_bits_), + StrCat("mantissa_bits=", mantissa_bits_)}; +} + +bool HloReducePrecisionInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = + static_cast(other); + // A reduce-precision operation is determined by the bit sizes. + return exponent_bits() == casted_other.exponent_bits() && + mantissa_bits() == casted_other.mantissa_bits(); +} + +std::unique_ptr +HloReducePrecisionInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return MakeUnique( + shape, new_operands[0], exponent_bits(), mantissa_bits()); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h new file mode 100644 index 0000000000000000000000000000000000000000..6749d875559008b3a2bd479ff075c83d85d87509 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -0,0 +1,727 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// All HloInstruction subclasses are put in this file. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { + +class HloBatchNormInstruction : public HloInstruction { + public: + // Returns feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + int64 feature_index() const { return feature_index_; } + + // Returns a epsilon value associated with the instruction. The is a small + // number added to the variance to avoid divide-by-zero error. + float epsilon() const { return epsilon_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, float epsilon, + int64 feature_index); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // A small float number added to the variance to avoid divide-by-zero error. + float epsilon_ = 0.0f; + + // An integer value representing the index of the feature dimension. + int64 feature_index_ = -1; +}; + +class HloBatchNormTrainingInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormTrainingInstruction(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormInferenceInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormInferenceInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, HloInstruction* mean, HloInstruction* variance, + float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloBatchNormGradInstruction : public HloBatchNormInstruction { + public: + explicit HloBatchNormGradInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* mean, HloInstruction* variance, + HloInstruction* grad_output, float epsilon, int64 feature_index); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloFftInstruction : public HloInstruction { + public: + explicit HloFftInstruction(const Shape& shape, HloInstruction* operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + FftType fft_type() const { return fft_type_; } + + const std::vector& fft_length() const { return fft_length_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector fft_length_; +}; + +class HloSendRecvInstruction : public HloInstruction { + public: + // Returns the channel id associated with the instruction. The id is + // shared between each Send/Recv pair and is globally unique to identify each + // channel. + int64 channel_id() const { return channel_id_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + protected: + explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, + int64 channel_id); + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Represents a unique identifier for each Send/Recv instruction pair. + int64 channel_id_; +}; + +class HloSendInstruction : public HloSendRecvInstruction { + public: + explicit HloSendInstruction(HloInstruction* operand, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloSendDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloSendDoneInstruction(HloSendInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvInstruction(const Shape& shape, int64 channel_id); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloRecvDoneInstruction : public HloSendRecvInstruction { + public: + explicit HloRecvDoneInstruction(HloRecvInstruction* operand); + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; +}; + +class HloReverseInstruction : public HloInstruction { + public: + explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloConcatenateInstruction : public HloInstruction { + public: + explicit HloConcatenateInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + int64 dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Accessor for the dimension in which a concatenate HLO should occur. + int64 concatenate_dimension() const { return dimensions(0); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloReduceInstruction : public HloInstruction { + public: + explicit HloReduceInstruction( + const Shape& shape, HloInstruction* arg, HloInstruction* init_value, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reduce_computation); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloTransposeInstruction : public HloInstruction { + public: + explicit HloTransposeInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice dimensions); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns whether this instruction does a rank-2 transposition. + bool IsRank2Transpose() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloBroadcastInstruction : public HloInstruction { + public: + explicit HloBroadcastInstruction( + const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice broadcast_dimension); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloMapInstruction : public HloInstruction { + public: + explicit HloMapInstruction( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloComputation* map_computation, + tensorflow::gtl::ArraySlice static_operands = {}); + // Returns the dimension sizes or numbers associated with this instruction. + const std::vector& dimensions() const override { return dimensions_; } + int64 dimensions(int64 index) const override { return dimensions()[index]; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + std::vector dimensions_; +}; + +class HloSliceInstruction : public HloInstruction { + public: + explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + HloInstructionProto ToProto() const override; + + // Returns the start index in the given dimension for a slice node. + int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; } + const std::vector& slice_starts() const { return slice_starts_; } + + // Returns the (exclusive) limit index in the given dimension for a slice + // node. + int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; } + const std::vector& slice_limits() const { return slice_limits_; } + + // Returns the stride in the given dimension for a slice node. + int64 slice_strides(int64 dimension) const { + return slice_strides_[dimension]; + } + const std::vector& slice_strides() const { return slice_strides_; } + + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // Describes the [begin, end) index range for a slice. + std::vector slice_starts_; + std::vector slice_limits_; + std::vector slice_strides_; + + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; +}; + +class HloConstantInstruction : public HloInstruction { + public: + explicit HloConstantInstruction(std::unique_ptr literal); + // Returns the literal associated with this instruction. + const Literal& literal() const { return *literal_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Change the layout for an Constant Hlo instruction to match new_layout. For + // tuple shaped constants shape_index is the path to the internal array + // subshape whose layout needs to be changed. + void RelayoutConstant(const Layout& new_layout, + const ShapeIndex& shape_index = {}); + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloTraceInstruction : public HloInstruction { + public: + explicit HloTraceInstruction(const string& tag, HloInstruction* operand); + // Returns a tag to be used in tracing. + string TracingTag() const { return literal_->GetR1U8AsString(); } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + // TODO(b/36360764): Remove unique_ptr wrapping. + std::unique_ptr literal_; +}; + +class HloFusionInstruction : public HloInstruction { + public: + explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, + HloInstruction* fused_root); + + explicit HloFusionInstruction( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + + // Merges the fused instructions from 'instruction_to_merge' into the + // fused instruction set of 'this', updating operands as necessary. + // + // Predondition: 'instruction_to_merge' must be an operand of 'this'. + void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); + + // Merges the fused instructions from instruction_to_merge into the fused + // instruction set of 'this' and generates multioutput fusion instructions. + // All the users of instruction_to_merge will be redirected to 'this' + // instruction. instruction_to_merge will be removed from its parent + // computation. + void MergeFusionInstructionIntoMultiOutput( + HloFusionInstruction* instruction_to_merge); + + // Fuses the given instruction in this fusion instruction. instruction_to_fuse + // is cloned and the clone is placed in the fusion + // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather + // than moved to cleanly handle the case where the instruction has a use + // outside the fusion instruction. Moving such an instruction into a fusion + // instruction would violate the single-result invariant of HLO instructions + // and significantly complicate code generation. + HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse); + } + + // Fuses the given instruction in this fusion instruction and generate + // multioutput fusion instruction. A clone of the instruction_to_fuse will + // be part of the output of fusion instructions. The users of + // instruction_to_fuse will be redirected to this fusion instructions. + // instruction_to_fuse will be removed from its parent computation. + HloInstruction* FuseInstructionIntoMultiOutput( + HloInstruction* instruction_to_fuse) { + return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true); + } + + // Returns the computation for this fused instruction. + HloComputation* fused_instructions_computation() const; + + // Returns the root instruction of the fused expression contained within this + // fusion instruction. + HloInstruction* fused_expression_root() const; + + // Returns the list of fused instructions inside this fusion instruction. The + // returned type is a range of HloInstruction*s. + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Gets the number of instructions inside this fusion instruction. + int64 fused_instruction_count() const; + + // Returns the fused parameter instruction in this fusion instruction + // corresponding to the given parameter number. + HloInstruction* fused_parameter(int64 parameter_number) const; + + // Returns the vector of fused parameters inside this fusion instruction. + const std::vector& fused_parameters() const; + + // Returns true if this instruction is a fusion instruction that generates + // multiple outputs. + const bool IsMultiOutputFusion() const { + return fused_expression_root()->opcode() == HloOpcode::kTuple; + } + + FusionKind fusion_kind() const { return fusion_kind_; } + + void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + + private: + // Fuses the given instruction into this fusion instruction. When add_output + // is false (which is the default), instruction_to_fuse is cloned and the + // clone is placed in the fusion instruction. instruction_to_fuse is + // unchanged. + // + // When add_output is true, a clone of the instruction_to_fuse will be part + // of the output of fusion instructions. The users of instruction_to_fuse + // will be redirected to this fusion instructions. instruction_to_fuse will + // be removed from its parent computation. + HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + // Clones the given instruction_to_fuse and insert the clone into this fusion + // instruction. If add_output is true, a clone of instruction_to_fuse will + // be in the output of the this fusion instruction (part of the tuple of the + // fusion root). + HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse, + bool add_output = false); + + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The type of the fusion. Used by kFusion only. + FusionKind fusion_kind_; +}; + +class HloRngInstruction : public HloInstruction { + public: + explicit HloRngInstruction( + const Shape& shape, RandomDistribution distribution, + tensorflow::gtl::ArraySlice parameters); + // Returns the random distribution for this rng node. + RandomDistribution random_distribution() const { return distribution_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IsElementwiseImpl( + const tensorflow::gtl::optional& operand_idx) const override; + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The distribution requested for random number generation. + RandomDistribution distribution_; +}; + +class HloParameterInstruction : public HloInstruction { + public: + explicit HloParameterInstruction(int64 parameter_number, const Shape& shape, + const string& name); + int64 parameter_number() const { return parameter_number_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + string OperandsToStringWithCanonicalNameMap( + const HloPrintOptions& options, + CanonicalNameMap* canonical_name_map) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 parameter_number_ = 0; +}; + +class HloGetTupleElementInstruction : public HloInstruction { + public: + explicit HloGetTupleElementInstruction(const Shape& shape, + HloInstruction* operand, int64 index); + // Returns the tuple index associated with this instruction. + int64 tuple_index() const { return tuple_index_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + int64 tuple_index_ = -1; +}; + +class HloReducePrecisionInstruction : public HloInstruction { + public: + explicit HloReducePrecisionInstruction(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits); + // Returns the number of exponent bits for a reduce-precision node. + int32 exponent_bits() const { return exponent_bits_; } + // Returns the number of mantissa bits for a reduce-precision node. + int32 mantissa_bits() const { return mantissa_bits_; } + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const override; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc similarity index 95% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.cc rename to tensorflow/compiler/xla/service/hlo_lexer.cc index 350db126535e418cbfa914edd958f47ba90a3ee5..f0d9fdbc8f86da0bb9d7f9235239df677c9506bc 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/service/hlo_lexer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include @@ -26,9 +26,8 @@ limitations under the License. #include "tensorflow/core/platform/regexp.h" namespace xla { -namespace tools { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; namespace { @@ -67,12 +66,12 @@ bool HloLexer::CanDereference(const char* ptr) const { return ptr < buf_.end() && ptr >= buf_.begin(); } -StringPiece HloLexer::StringPieceFromPointers(const char* begin, - const char* end) const { +tensorflow::StringPiece HloLexer::StringPieceFromPointers( + const char* begin, const char* end) const { CHECK(begin <= end); CHECK(begin == buf_.end() || CanDereference(begin)); CHECK(end == buf_.end() || CanDereference(end)); - return StringPiece(begin, end - begin); + return tensorflow::StringPiece(begin, end - begin); } tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( @@ -197,7 +196,8 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kAttributeName; } - StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + tensorflow::StringPiece identifier = + StringPieceFromPointers(token_start_, current_ptr_); // See if this is a keyword. #define KEYWORD(STR) \ @@ -332,23 +332,24 @@ std::pair HloLexer::GetLineAndColumn(LocTy location) const { line_no_cache_.last_query = ptr; line_no_cache_.line_no_of_query = line_no; size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); - if (line_offset == StringPiece::npos) { + if (line_offset == tensorflow::StringPiece::npos) { line_offset = 0; } return {line_no, ptr - start - line_offset}; } -StringPiece HloLexer::GetLine(LocTy loc) const { +tensorflow::StringPiece HloLexer::GetLine(LocTy loc) const { if (!CanDereference(loc)) { return "LINE OUT OF RANGE"; } size_t line_start = StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); - const char* start = line_start == StringPiece::npos + const char* start = line_start == tensorflow::StringPiece::npos ? buf_.begin() : buf_.begin() + line_start + 1; size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); - const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + const char* end = + line_end == tensorflow::StringPiece::npos ? buf_.end() : loc + line_end; return StringPieceFromPointers(start, end); } @@ -370,7 +371,7 @@ TokKind HloLexer::LexString() { static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; if (RE2::Consume(&consumable, *escaping_pattern)) { current_ptr_ = consumable.begin(); - StringPiece raw = + tensorflow::StringPiece raw = StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); string error; if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { @@ -453,5 +454,4 @@ string TokKindToString(TokKind kind) { } } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h similarity index 90% rename from tensorflow/compiler/xla/tools/parser/hlo_lexer.h rename to tensorflow/compiler/xla/service/hlo_lexer.h index 27880b9b8afbfa58abfedc3b2cecd5236b78a6d6..ceb674f25e94ac3ac2e6a4a0687a93ffdcd065e0 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/service/hlo_lexer.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ #include -#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/service/hlo_token.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -27,9 +27,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. class HloLexer { public: explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { @@ -57,7 +59,7 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - int64 GetInt64Val() const { + tensorflow::int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; } @@ -114,7 +116,7 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - int64 int64_val_; + tensorflow::int64 int64_val_; double decimal_val_; struct LineNoCacheTy { @@ -125,7 +127,6 @@ class HloLexer { mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc index 8e2e2c7627ba6ac9e5078446056917a07436cbd7..0275294a1a86cef13e5b267ad578f30cc18858dc 100644 --- a/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_liveness_analysis_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -59,7 +59,7 @@ class HloLivenessAnalysisTest : public HloTestBase { // Test that add instruction at entry root is live at all output shape indices. TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -75,7 +75,7 @@ TEST_F(HloLivenessAnalysisTest, AddAtEntryRoot) { // Test that a dead add instruction is marked as dead by analysis. TEST_F(HloLivenessAnalysisTest, DeadAdd) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -94,7 +94,7 @@ TEST_F(HloLivenessAnalysisTest, DeadAdd) { // Test that all output shape indices of entry root tuple (and defining // instruction in its output) are marked live. TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -113,7 +113,7 @@ TEST_F(HloLivenessAnalysisTest, TupleAtEntryRoot) { // Tests that all outputs of nested tuple and entry root (and defining // instruction values appearing in its output) are marked live. TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(1) @@ -140,7 +140,7 @@ TEST_F(HloLivenessAnalysisTest, NestedTupleAtEntryRoot) { // Tests that GTE at entry root of Tuple instruction only propgates liveness // to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfTuple) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -162,7 +162,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfTuple) { // Tests that GTE at entry root of nested Tuple instruction only propgates // liveness to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -199,7 +199,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfNestedTuple) { // Tests that GTE of GTE (at entry root) of nested Tuple instruction only // propgates liveness to the live elements in tuple. TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleModule ENTRY SimpleComputation { constant.1 = s32[] constant(0) @@ -240,7 +240,7 @@ TEST_F(HloLivenessAnalysisTest, GteOfGteOfNestedTuple) { // Test that live/dead while tuple elements are marked live/dead correctly. TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -291,7 +291,7 @@ TEST_F(HloLivenessAnalysisTest, WhileWithDeadTupleElement) { // Tests that a tuple element live in while.cond computation, propagates // liveness to while.body.root/while.result/while.operand (where it is unused). TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -345,7 +345,7 @@ TEST_F(HloLivenessAnalysisTest, WhileCondPropagatesLiveness) { // Tests that a use of while.result{0} propagates liveness to // while.body.param{1} to while.body.root{1}, and then to while.body.param{2}. TEST_F(HloLivenessAnalysisTest, WhileWithLiveTupleElements) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[], s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index c33bdadf1c7145bf2aff09b01423c6c21382da0c..c570b420c21fed4d7828feb24ee5c7859db94a79 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MATCHERS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -324,6 +325,12 @@ inline ::testing::Matcher Sharding( return ::testing::MakeMatcher( new ::xla::testing::HloShardingMatcher(sharding)); } +// Matcher for Sharding from sharding string +inline ::testing::Matcher Sharding( + tensorflow::StringPiece sharding) { + return ::testing::MakeMatcher(new ::xla::testing::HloShardingMatcher( + ParseSharding(sharding).ValueOrDie())); +} // Verifies that no HloSharding is set for an HLO instruction. inline ::testing::Matcher NoSharding() { return ::testing::MakeMatcher( diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 016cc01e33840aa195dfc0a21e8ac8f3d24a3e06..9a3010cf1ff75e840130d8442bbe26d6041cef25 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace op = xla::testing::opcode_matchers; using ::testing::_; @@ -147,6 +147,18 @@ TEST(HloMatchersTest, ShardingMatcher) { "param.1"); p1->set_sharding(HloSharding::AssignDevice(1)); + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {7}), ShapeUtil::MakeShape(S32, {9}), + ShapeUtil::MakeShape(F32, {11})}); + auto p2 = HloInstruction::CreateParameter(1, tuple_shape, "param.2"); + Array assignment({2}); + assignment.SetValues({0, 1}); + auto sharding = HloSharding::Tuple( + tuple_shape, + {HloSharding::Tile(ShapeUtil::MakeShape(F32, {5}), assignment), + HloSharding::AssignDevice(1), HloSharding::Replicate()}); + p2->set_sharding(sharding); + EXPECT_THAT(p0.get(), op::NoSharding()); EXPECT_THAT(p0.get(), ::testing::Not(op::Sharding(HloSharding::AssignDevice(1)))); @@ -155,6 +167,11 @@ TEST(HloMatchersTest, ShardingMatcher) { ::testing::Not(op::Sharding(HloSharding::AssignDevice(0)))); EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1))); + EXPECT_THAT( + p2.get(), + op::Sharding( + "{{f32[5] devices=[2]0,1}, {maximal device=1}, {replicated}}")); + EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))), "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: " "{maximal device=1})"); @@ -178,7 +195,7 @@ ENTRY DotOperationFusion_TransposeFusion { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, op::Dot(op::Parameter(0), op::Parameter(1), diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index fbf1d58007e318a8a08aa9e11d9d54811533703e..9c59374b4a9d7e3dbfb99d8a6b30d4230e553658 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -32,15 +32,6 @@ limitations under the License. namespace xla { -HloModule::HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config) - : name_(NameUniquer::GetSanitizedName(name)), - config_(config), - has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle), - unique_id_(next_unique_module_id_++) {} - HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(NameUniquer::GetSanitizedName(name)), config_(config), @@ -234,8 +225,7 @@ HloModuleProto HloModule::ToProto() const { /* static */ StatusOr> HloModule::CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle) { + const HloModuleProto& proto, const HloModuleConfig& module_config) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. TF_RET_CHECK(proto.has_program_shape()) @@ -287,8 +277,7 @@ StatusOr> HloModule::CreateFromProto( } TF_RET_CHECK(entry != nullptr); - auto module = MakeUnique(proto.name(), entry_computation_handle, - module_config); + auto module = MakeUnique(proto.name(), module_config); // Sort the computations in the proto id's order. std::sort(computations.begin(), computations.end(), @@ -401,7 +390,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( // as a parameter in the new function. arguments.push_back(old_operand); *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter( - parameter_count, old_operand->shape(), "")); + parameter_count, old_operand->shape(), "p")); ++parameter_count; } TF_CHECK_OK( @@ -496,7 +485,18 @@ std::list HloModule::MakeComputationPostOrder() const { added_computations.insert(computation.get()); } } - CHECK_EQ(post_order.size(), computations_.size()); + if (post_order.size() != computations_.size()) { + for (HloComputation* computation : post_order) { + LOG(ERROR) << "Post Order: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + for (auto& computation : computations_) { + LOG(ERROR) << "Computations: " << computation->name() << " (" + << computation->parent()->name() << ")"; + } + LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size() + << " computation_count=" << computations_.size(); + } return post_order; } @@ -514,57 +514,26 @@ std::vector HloModule::MakeNonfusionComputations() const { std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; auto module = MakeUnique(name_ + "-" + suffix, config_); - module->entry_computation_handle_ = entry_computation_handle_; - module->has_entry_computation_handle_ = has_entry_computation_handle_; - - std::unordered_map clone_map; - for (auto& computation : computations_) { - if (computation->IsFusionComputation()) { - // Cloning of a fused computation is handled by its fusion instruction. - continue; - } - // When cloning a computation, pass in the new module, so that for any - // fusion instruction in this computation, the fused computation will be - // deep cloned to the new module. - auto cloned_computation = computation->Clone(suffix, module.get()); - InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); - - if (entry_computation_ == computation.get()) { - module->AddEntryComputation(std::move(cloned_computation)); - } else { - module->AddEmbeddedComputation(std::move(cloned_computation)); - } - } - - for (auto& cloned_computation : module->computations_) { - for (auto* instruction : cloned_computation->instructions()) { - // Rewrite instruction's called_computation to point to the cloned - // computations. - instruction->ReplaceCalledComputations([&](HloComputation* hlo) { - if (hlo->IsFusionComputation()) { - // Cloning of a fused computation has already been handled when its - // fusion instruction is cloned. So this hlo computation is already - // the cloned one. - return hlo; - } - return FindOrDie(clone_map, hlo); - }); - } - } + HloCloneContext context(module.get(), suffix); + auto cloned_computation = entry_computation_->Clone(suffix, &context); + module->AddEntryComputation(std::move(cloned_computation)); return module; } -HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { - HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); - TF_CHECK_OK( - clone->root_instruction()->Accept([this](HloInstruction* instruction) { - instruction->ReplaceCalledComputations([this](HloComputation* callee) { - return DeepCloneComputation(callee); - }); - return Status::OK(); - })); - return clone; +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation, + HloCloneContext* context) { + HloComputation* new_computation; + if (context != nullptr) { + if ((new_computation = context->FindComputation(computation)) != nullptr) { + return new_computation; + } + new_computation = + AddEmbeddedComputation(computation->Clone(context->suffix(), context)); + } else { + new_computation = AddEmbeddedComputation(computation->Clone("")); + } + return new_computation; } uint64 HloModule::RandomNew64() const { diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 02918c377776b73f2086fe41afc406567a12af4c..757e65bda286d983d05e5a791aa7dffe97bac945 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -26,11 +26,11 @@ limitations under the License. #include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_clone_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -56,10 +56,6 @@ namespace xla { // attached to. class HloModule { public: - HloModule(const string& name, - const VersionedComputationHandle& entry_computation_handle, - const HloModuleConfig& config); - // Constructor without a versioned computation handle. This constructor should // only be used for HloModules used outside of the XLA service (eg // tests). The versioned handle is used by the service in the compilation @@ -94,8 +90,10 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Performs a deep clone of the computation, by recursively cloning all - // the called computations as well. - HloComputation* DeepCloneComputation(HloComputation* computation); + // the called computations as well. If the clone context is specified, it + // will be populated with the cloned object mappings. + HloComputation* DeepCloneComputation(HloComputation* computation, + HloCloneContext* context = nullptr); // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { @@ -123,10 +121,6 @@ class HloModule { return config_.device_entry_computation_layout(); } - const VersionedComputationHandle& entry_computation_handle() const { - return entry_computation_handle_; - } - // Gets the computations in this module. // // Returns a view of HloComputation*s, so you can iterate over this in the @@ -185,9 +179,7 @@ class HloModule { // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; static StatusOr> CreateFromProto( - const HloModuleProto& proto, const HloModuleConfig& module_config, - const VersionedComputationHandle& entry_computation_handle = - VersionedComputationHandle()); + const HloModuleProto& proto, const HloModuleConfig& module_config); // Creates and returns an HloModuleConfig with an appropriate program shape // for the HLO module in the given proto. @@ -261,10 +253,6 @@ class HloModule { mutable std::mt19937_64 rng_{42}; mutable tensorflow::mutex rng_mutex_; - // Versioned handle of the entry computation of the module. - bool has_entry_computation_handle_ = false; - VersionedComputationHandle entry_computation_handle_; - // Unique name generator for computation and instruction names, which are // unique per module. NameUniquer computation_name_uniquer_{/*separator=*/"."}; diff --git a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc index 53b7d0ed3964ca8a2c3bb73c62015a1c7dbfe487..363862e4905fc13a4ef07aeaac255259fc6b86ba 100644 --- a/tensorflow/compiler/xla/service/hlo_module_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_dce_test.cc @@ -19,11 +19,11 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" @@ -73,7 +73,7 @@ class HloModuleDceTest : public HloTestBase { // Tests that a while with all outputs live is unmodified. TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -110,7 +110,7 @@ TEST_F(HloModuleDceTest, WhileWithLiveOutputs) { // Tests a while loop with one unused output (which is used in the while loop // body by an instruction with side-effects: rng) is unmodified. TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], f32[]) parameter(0) @@ -150,7 +150,7 @@ TEST_F(HloModuleDceTest, WhileWithUnusedSideEffectingTupleElement) { // Tests that a while loop with one dead tuple element at {1} has its while // loop body modified to make that tuple element pass-through the while body. TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -193,7 +193,7 @@ TEST_F(HloModuleDceTest, OneWhileWithDeadTupleElement) { // dead in while.body{1} and at while.result{1}) propgates liveness of this // tuple element to while.body{1} and at while.result{1}. TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body { loop_var.1 = (s32[], s32[]) parameter(0) @@ -235,7 +235,7 @@ TEST_F(HloModuleDceTest, OneWhileWithTupleElementUsedByCond) { // Tests that HloModuleDCE can remove a dead tuple element at index {1} between // two dependent while loops. TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body0 { loop_var.1 = (s32[], s32[3]{0}) parameter(0) @@ -303,7 +303,7 @@ TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElement) { // Tests that HloModuleDCE can remove a dead tuple element at while.1{0} and // while.2{1}, between two dependent while loops. TEST_F(HloModuleDceTest, TwoWhilesWithDeadTupleElementSwizzled) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule SimpleLoop SimpleLoop.body0 { loop_var.1 = (s32[3]{0}, s32[]) parameter(0) diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index b4cd3c730e323b8459312edbebc564e08f9d6840..bf33640db16638803f4f8e6c66f35d6bb6e2c9fe 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -87,6 +87,7 @@ Status HloModuleGroupMetadata::Build() { << "Peer instruction does not match the computation kind"; TF_RETURN_IF_ERROR( AddCompanion(tracked->instruction(), peer_tracked->instruction())); + tracked_instructions_comms_[tracked->instruction()].push_back(hlo); } // Add the parents of companion instructions (they must be all of the same @@ -112,27 +113,43 @@ Status HloModuleGroupMetadata::Build() { } } TF_RETURN_IF_ERROR(VerifyCompanionSets()); + if (VLOG_IS_ON(4)) { + DumpCollectedStats(); + } return Status::OK(); } Status HloModuleGroupMetadata::VerifyCompanionSets() const { - // TODO(dlibenzi): Migrate this to use the device instead of module ID, once - // the kDomain CL goes in. for (const auto& companions : companion_sets_) { // A companion set must be composed at most of an instruction per // device/module. std::unordered_set devices; for (HloInstruction* instruction : *companions) { - int64 device = GetModuleId(instruction->parent()->parent()); - if (!devices.insert(device).second) { - std::stringstream ss; - ss << "Companion set:" << std::endl; - for (HloInstruction* hlo : *companions) { - ss << " " << hlo->name() << " (" - << GetModuleId(hlo->parent()->parent()) << ")" << std::endl; + // Go through all the communicating instructions (send, recv) of the given + // companion, and record their device. + auto it = tracked_instructions_comms_.find(instruction); + if (it == tracked_instructions_comms_.end()) { + // Companions can be added even if they have no communicating + // instructions, if they are parent of companions. + continue; + } + std::unordered_set comm_devices; + for (HloInstruction* comm_instruction : it->second) { + auto device = GetInstructionDevice(*comm_instruction); + TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString() + << " does not have a device"; + comm_devices.insert(*device); + } + for (int64 device : comm_devices) { + if (!devices.insert(device).second) { + std::stringstream ss; + ss << "Companion set:" << std::endl; + for (HloInstruction* hlo : *companions) { + ss << " " << hlo->name() << std::endl; + } + ss << "has multiple instructions on the same device"; + return FailedPrecondition("%s", ss.str().c_str()); } - ss << "has multiple instructions on the same device"; - return FailedPrecondition("%s", ss.str().c_str()); } } } @@ -223,6 +240,28 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const { LOG(FATAL) << "unknown module"; } +tensorflow::gtl::optional HloModuleGroupMetadata::GetInstructionDevice( + const HloInstruction& instruction) const { + // The module group metadata can be created in both "single module, multiple + // devices" and "multiple modules, no explicit devices" fashions. + // The API returns an optional even though the current implementation always + // returns a device, to account for cases where we cannot guess a device. + // In such cases the VerifyChannelInstructions() will return proper errors. + tensorflow::gtl::optional device = + instruction.sharding_unique_device(); + if (!device) { + device = GetModuleId(instruction.parent()->parent()); + } + return device; +} + +int64 HloModuleGroupMetadata::GetDeviceModulesCount() const { + return std::count_if(modules_.begin(), modules_.end(), + [](const HloModule* module) { + return !module->config().is_host_module(); + }); +} + Status HloModuleGroupMetadata::RecordInstructions() { const auto visitor = [this](HloInstruction* hlo) -> Status { if (hlo->opcode() == HloOpcode::kWhile) { @@ -284,6 +323,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { TF_RETURN_IF_ERROR(computation->Accept(visitor)); } } + VLOG(2) << "Created " << channels_.size() << " channels"; return Status::OK(); } @@ -346,26 +386,38 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() { if (!ShapeUtil::Compatible(send_shape, recv_shape)) { return FailedPrecondition("send/recv shapes do not match"); } - const HloModule* send_module = channel.send->parent()->parent(); - const HloModule* send_done_module = channel.send_done->parent()->parent(); - if (send_module != send_done_module) { + auto send_device = GetInstructionDevice(*channel.send); + auto send_done_device = GetInstructionDevice(*channel.send_done); + if (!send_device) { + return FailedPrecondition("send instruction must have a device: %s", + channel.send->ToString().c_str()); + } + if (!send_done_device) { + return FailedPrecondition("send_done instruction must have a device: %s", + channel.send_done->ToString().c_str()); + } + if (*send_device != *send_done_device) { return FailedPrecondition( "send and send-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(send_module), GetModuleId(send_done_module)); + channel.id, *send_device, *send_done_device); } - const HloModule* recv_module = channel.recv->parent()->parent(); - const HloModule* recv_done_module = channel.recv_done->parent()->parent(); - if (recv_module != recv_done_module) { + auto recv_device = GetInstructionDevice(*channel.recv); + auto recv_done_device = GetInstructionDevice(*channel.recv_done); + if (!recv_done_device) { + return FailedPrecondition("recv_done instruction must have a device: %s", + channel.recv_done->ToString().c_str()); + } + if (*recv_device != *recv_done_device) { return FailedPrecondition( "recv and recv-done (channel=%lld) must be on the same device: %lld " "vs. %lld", - channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module)); + channel.id, *recv_device, *recv_done_device); } - if (send_module == recv_module) { + if (*send_device == *recv_device) { return FailedPrecondition( "send and recv (channel=%lld) must be on different devices: %lld", - channel.id, GetModuleId(send_module)); + channel.id, *send_device); } } @@ -402,4 +454,36 @@ Status HloModuleGroupMetadata::CheckCommunicatingInstruction( return FailedPrecondition("channel is used in disallowed computation"); } +void HloModuleGroupMetadata::DumpCollectedStats() const { + std::map, int64> communication_histogram; + for (auto& channel : channels_) { + auto from_device = GetInstructionDevice(*channel.send); + auto to_device = GetInstructionDevice(*channel.recv); + LOG(INFO) << "Channel " << channel.id << ": from_device=" << *from_device + << " to_device=" << *to_device << " send=" << channel.send->name() + << " send_done=" << channel.send_done->name() + << " recv=" << channel.recv->name() + << " recv_done=" << channel.recv_done->name(); + communication_histogram[std::pair(*from_device, + *to_device)] += 1; + } + for (auto& fromto_count : communication_histogram) { + LOG(INFO) << "From " << fromto_count.first.first << " to " + << fromto_count.first.second << ": " << fromto_count.second; + } + for (auto& companion_set : companion_sets_) { + LOG(INFO) << "Companion set:"; + for (HloInstruction* instruction : *companion_set) { + LOG(INFO) << " " << instruction->name(); + } + } + for (auto& instruction_comm : tracked_instructions_comms_) { + LOG(INFO) << "Communicating instruction " << instruction_comm.first->name(); + for (HloInstruction* instruction : instruction_comm.second) { + auto device = GetInstructionDevice(*instruction); + LOG(INFO) << " " << instruction->name() << " on device " << *device; + } + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 3ef4542f9129632de4975688ae7e9e2c5f43a7ee..ffde3a332dfc141ca928a44cfdf4686900e9f47b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -148,6 +149,15 @@ class HloModuleGroupMetadata { // the module in the module vector. int64 GetModuleId(const HloModule* module) const; + // Retrieves the device an instruction is assigned to. Either from the + // sharding information, or from the ordinal of the module the instruction + // is in. + tensorflow::gtl::optional GetInstructionDevice( + const HloInstruction& instruction) const; + + // Returns the number of modules for devices (excluding the host module). + int64 GetDeviceModulesCount() const; + // Returns the companion instructions for the given instruction. // // Precondition: IsCompanionWhile(instruction) is true. @@ -220,6 +230,9 @@ class HloModuleGroupMetadata { return it != tracked_instructions_.end() ? &it->second : nullptr; } + // Dump all the collected module group statistics to the logs. + void DumpCollectedStats() const; + // List of all companion instructions sets in the module. std::vector>> companion_sets_; @@ -231,6 +244,11 @@ class HloModuleGroupMetadata { tensorflow::gtl::FlatMap tracked_instructions_; + // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of + // communicating instructions within the proper called computation(s). + tensorflow::gtl::FlatMap> + tracked_instructions_comms_; + // All channels in the module. std::vector channels_; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index ac7cd2f2f517cf8831416d9265fc48bbf9fce340..a35546f5f41b149d119ee141fd734da8bfd055b2 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -69,6 +69,7 @@ namespace xla { V(kCrossReplicaSum, "cross-replica-sum") \ V(kCustomCall, "custom-call") \ V(kDivide, "divide") \ + V(kDomain, "domain") \ V(kDot, "dot") \ V(kDynamicSlice, "dynamic-slice") \ V(kDynamicUpdateSlice, "dynamic-update-slice") \ @@ -80,6 +81,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index cd2ce5c69f030c65b889d67e082a3677b8739ddb..774345124b4ad62e35d9423a23f1dbaa28e44d80 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index ee526d8dd7f7e81b3a846741d3e452935f486bd2..985f3fa64d8767b0c0063ee900f7d11c3b7f6d4a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -183,6 +183,10 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: + // TODO(dimvar): HloModuleSequence is not a good name because it sounds like + // a sequence of modules, instead of a map of schedules for all computations + // in a module. We should change it at some point. + // // A sequence of instructions for each computation in the module. using HloModuleSequence = tensorflow::gtl::FlatMap module, - tools::Parse(module_str)); + ParseHloString(module_str)); DependencyHloOrdering ordering(module.get()); ordering.ToString(); // Shouldn't crash. } @@ -347,7 +347,7 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); TF_ASSERT_OK_AND_ASSIGN(auto dataflow, HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc similarity index 86% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.cc rename to tensorflow/compiler/xla/service/hlo_parser.cc index d0e7af8844203da93dac5b45cb7e13916448dd47..fef475380c5c810e1c4712406dde6b1135be3d97 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -24,18 +26,17 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { -namespace tools { namespace { -using tensorflow::StringPiece; -using tensorflow::gtl::optional; -using tensorflow::str_util::Join; -using tensorflow::str_util::Split; -using tensorflow::str_util::SplitAndParseAsInts; -using tensorflow::strings::Printf; -using tensorflow::strings::StrAppend; -using tensorflow::strings::StrCat; +using ::tensorflow::StringPiece; +using ::tensorflow::gtl::optional; +using ::tensorflow::str_util::Join; +using ::tensorflow::str_util::Split; +using ::tensorflow::str_util::SplitAndParseAsInts; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; const double kF16max = 65504; @@ -56,6 +57,11 @@ class HloParser { // Returns the error information. string GetError() const { return Join(error_, "\n"); } + // Stand alone parsing utils for various aggregate data types. + StatusOr ParseShardingOnly(); + StatusOr ParseWindowOnly(); + StatusOr ParseConvolutionDimensionNumbersOnly(); + private: // ParseXXX returns false if an error occurred. bool ParseHloModule(); @@ -78,11 +84,15 @@ class HloParser { // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. - bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); - bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, tensorflow::int64 linear_index, + Literal* literal); + bool SetValueInLiteral(bool value, tensorflow::int64 linear_index, + Literal* literal); template - bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + bool SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal); bool ParseOperands(std::vector* operands); @@ -94,9 +104,15 @@ class HloParser { // Describes the start, limit, and stride on every dimension of the operand // being sliced. struct SliceRanges { - std::vector starts; - std::vector limits; - std::vector strides; + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // The data parsed for the kDomain instruction. + struct DomainData { + std::unique_ptr entry_metadata; + std::unique_ptr exit_metadata; }; // Types of attributes. @@ -117,6 +133,7 @@ class HloParser { kMetadata, kFusionKind, kDistribution, + kDomain, }; struct AttrConfig { @@ -164,21 +181,27 @@ class HloParser { bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. bool ParseInstructionNames(std::vector* instructions); - bool ParseWindow(Window* window); + // Pass expect_outer_curlies == true when parsing a Window in the context of a + // larger computation. Pass false when parsing a stand-alone Window string. + bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + // Parses the metadata behind a kDOmain instruction. + bool ParseDomain(DomainData* domain); + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector* result); + bool ParseDxD(const string& name, std::vector* result); // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector>* pad); + bool ParseWindowPad(std::vector>* pad); bool ParseSliceRanges(SliceRanges* result); bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, std::vector* result); + const TokKind delim, + std::vector* result); bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); @@ -190,7 +213,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); - bool ParseInt64(int64* result); + bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); @@ -384,6 +407,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } + instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } @@ -447,7 +471,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { - int64 parameter_number; + tensorflow::int64 parameter_number; if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || @@ -563,11 +587,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional to_apply; + optional> replica_group_ids; + optional barrier; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + attrs["replica_group_ids"] = { + /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands)); + + if (replica_group_ids) { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, *replica_group_ids, + barrier ? *barrier : "")); + } else { + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, {}, barrier ? *barrier : "")); + } break; } case HloOpcode::kReshape: { @@ -579,6 +620,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } + case HloOpcode::kGenerateToken: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGenerateToken(operands)); + break; + } case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; @@ -602,7 +651,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecv: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { @@ -613,7 +662,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kRecvDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -627,7 +676,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSend: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -638,7 +687,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kSendDone: { - optional channel_id; + optional channel_id; attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -652,7 +701,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGetTupleElement: { - optional index; + optional index; attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || !ParseAttributes(attrs)) { @@ -710,7 +759,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kFft: { optional fft_type; - optional> fft_length; + optional> fft_length; attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, &fft_length}; @@ -723,7 +772,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kBroadcast: { - optional> broadcast_dimensions; + optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &broadcast_dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -735,7 +784,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kConcatenate: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs) || @@ -750,6 +799,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional to_apply; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + optional> dimensions; + attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, + &dimensions}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } @@ -761,7 +813,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional reduce_computation; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; - optional> dimensions_to_reduce; + optional> dimensions_to_reduce; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -774,7 +826,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReverse: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -818,7 +870,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDynamicSlice: { - optional> dynamic_slice_sizes; + optional> dynamic_slice_sizes; attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if (!ParseOperands(&operands, /*expected_size=*/2) || @@ -842,7 +894,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTranspose: { - optional> dimensions; + optional> dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if (!ParseOperands(&operands, /*expected_size=*/1) || @@ -856,7 +908,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormTraining: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/3) || @@ -872,7 +924,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormInference: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -889,7 +941,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kBatchNormGrad: { optional epsilon; attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional feature_index; + optional feature_index; attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, &feature_index}; if (!ParseOperands(&operands, /*expected_size=*/5) || @@ -960,8 +1012,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReducePrecision: { - optional exponent_bits; - optional mantissa_bits; + optional exponent_bits; + optional mantissa_bits; attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, &exponent_bits}; attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, @@ -1006,7 +1058,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kHostCompute: { optional channel_name; - optional cost_estimate_ns; + optional cost_estimate_ns; attrs["channel_name"] = {/*required=*/true, AttrTy::kString, &channel_name}; attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, @@ -1019,16 +1071,16 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kDot: { - optional> lhs_contracting_dims; + optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; - optional> rhs_contracting_dims; + optional> rhs_contracting_dims; attrs["rhs_contracting_dims"] = { /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; - optional> lhs_batch_dims; + optional> lhs_batch_dims; attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &lhs_batch_dims}; - optional> rhs_batch_dims; + optional> rhs_batch_dims; attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, &rhs_batch_dims}; @@ -1060,20 +1112,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kGather: { - optional> output_window_dims; + optional> output_window_dims; attrs["output_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; - optional> elided_window_dims; + optional> elided_window_dims; attrs["elided_window_dims"] = { /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; - optional> gather_dims_to_operand_dims; + optional> gather_dims_to_operand_dims; attrs["gather_dims_to_operand_dims"] = {/*required=*/true, AttrTy::kBracedInt64List, &gather_dims_to_operand_dims}; - optional index_vector_dim; + optional index_vector_dim; attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, &index_vector_dim}; - optional> window_bounds; + optional> window_bounds; attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, &window_bounds}; @@ -1093,12 +1145,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kDomain: { + DomainData domain; + attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDomain( + shape, operands[0], std::move(domain.entry_metadata), + std::move(domain.exit_metadata))); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - instruction->set_name(name); + instruction->SetAndSanitizeName(name); + if (instruction->name() != name) { + return Error(name_loc, + StrCat("illegal instruction name: ", name, + "; suggest renaming to: ", instruction->name())); + } // Add shared attributes like metadata to the instruction, if they were seen. if (sharding) { @@ -1118,7 +1187,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, instruction->set_metadata(*metadata); } if (backend_config) { - instruction->set_backend_config(std::move(*backend_config)); + instruction->set_raw_backend_config_string(std::move(*backend_config)); } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -1169,8 +1238,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; - std::vector devices; - std::vector tile_assignment_dimensions; + std::vector devices; + std::vector tile_assignment_dimensions; Shape tile_shape; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { @@ -1197,7 +1266,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } do { - int64 dim; + tensorflow::int64 dim; if (!ParseInt64(&dim)) { return false; } @@ -1209,7 +1278,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } do { - int64 device; + tensorflow::int64 device; if (!ParseInt64(&device)) { return false; } @@ -1268,10 +1337,10 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); *sharding->mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment_dimensions) { + for (tensorflow::int64 dim : tile_assignment_dimensions) { sharding->add_tile_assignment_dimensions(dim); } - for (int64 device : devices) { + for (tensorflow::int64 device : devices) { sharding->add_tile_assignment_devices(device); } } @@ -1280,6 +1349,34 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return true; } +// domain ::= '{' 'kind=' domain_kind ',' 'entry=' entry_sharding ',' +// 'exit=' exit_sharding '}' +bool HloParser::ParseDomain(DomainData* domain) { + std::unordered_map attrs; + optional kind; + optional entry_sharding; + optional exit_sharding; + attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; + attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; + attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (*kind == ShardingMetadata::KindName()) { + auto entry_sharding_ptr = MakeUnique( + HloSharding::FromProto(*entry_sharding).ValueOrDie()); + auto exit_sharding_ptr = MakeUnique( + HloSharding::FromProto(*exit_sharding).ValueOrDie()); + domain->entry_metadata = + MakeUnique(std::move(entry_sharding_ptr)); + domain->exit_metadata = + MakeUnique(std::move(exit_sharding_ptr)); + } else { + return TokenError(StrCat("unsupported domain kind: ", *kind)); + } + return true; +} + // '{' name+ '}' bool HloParser::ParseInstructionNames( std::vector* instructions) { @@ -1306,40 +1403,50 @@ bool HloParser::ParseInstructionNames( "expects '}' at the end of instruction name list"); } -bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, +bool HloParser::SetValueInLiteral(tensorflow::int64 value, + tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case S8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case S64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U8: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U32: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case U64: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); default: LOG(FATAL) << "unknown integral primitive type " << PrimitiveType_Name(shape.element_type()); } } -bool HloParser::SetValueInLiteral(double value, int64 linear_index, +bool HloParser::SetValueInLiteral(double value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { case F16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, literal); case BF16: - return SetValueInLiteralHelper(value, linear_index, literal); + return SetValueInLiteralHelper(value, linear_index, + literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1350,7 +1457,7 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, } } -bool HloParser::SetValueInLiteral(bool value, int64 linear_index, +bool HloParser::SetValueInLiteral(bool value, tensorflow::int64 linear_index, Literal* literal) { const Shape& shape = literal->shape(); switch (shape.element_type()) { @@ -1363,7 +1470,8 @@ bool HloParser::SetValueInLiteral(bool value, int64 linear_index, } template -bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, + tensorflow::int64 linear_index, Literal* literal) { // Check that linear_index is in range. if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { @@ -1475,7 +1583,7 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape) { - const int64 rank = ShapeUtil::Rank(shape); + const tensorflow::int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; } @@ -1483,8 +1591,8 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // Create a literal with the given shape in default layout. *literal = Literal::CreateFromDimensions(shape.element_type(), AsInt64Slice(shape.dimensions())); - int64 nest_level = 0; - int64 linear_index = 0; + tensorflow::int64 nest_level = 0; + tensorflow::int64 linear_index = 0; // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, // when we are parsing the 2nd '{' (right before '1'), we are seeing a @@ -1492,14 +1600,14 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, // the first '}' (right after '3'), it means the sub-array ends, and the // sub-array is supposed to contain exactly 3 elements, so check if // elems_seen_per_dim[1] is 3. - std::vector elems_seen_per_dim(rank); + std::vector elems_seen_per_dim(rank); auto get_index_str = [&elems_seen_per_dim](int dim) -> string { - std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), - elems_seen_per_dim.begin() + dim); + std::vector elems_seen_until_dim( + elems_seen_per_dim.begin(), elems_seen_per_dim.begin() + dim); return StrCat("[", Join(elems_seen_until_dim, ",", - [](string* out, const int64& num_elems) { - tensorflow::strings::StrAppend(out, num_elems - 1); + [](string* out, const tensorflow::int64& num_elems) { + StrAppend(out, num_elems - 1); }), "]"); }; @@ -1575,7 +1683,7 @@ bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { LocTy loc = lexer_.GetLoc(); - int64 value; + tensorflow::int64 value; if (!ParseInt64(&value)) { return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); @@ -1615,29 +1723,29 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, switch (shape.element_type()) { case PRED: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case S64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U8: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U32: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case U64: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F32: return ParseSparseLiteralHelper(literal, shape); case BF16: - return ParseSparseLiteralHelper(literal, shape); + return ParseSparseLiteralHelper(literal, shape); case F64: return ParseSparseLiteralHelper(literal, shape); default: @@ -1650,9 +1758,9 @@ bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, template bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, const Shape& shape) { - std::vector index; + std::vector index; - int64 rank = ShapeUtil::Rank(shape); + tensorflow::int64 rank = ShapeUtil::Rank(shape); *literal = MakeUnique(shape); @@ -1670,7 +1778,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, LocTy index_loc = lexer_.GetLoc(); index.clear(); if (lexer_.GetKind() == TokKind::kInt) { - int64 single_index = lexer_.GetInt64Val(); + tensorflow::int64 single_index = lexer_.GetInt64Val(); lexer_.Lex(); if (rank != 1) { return Error( @@ -1703,7 +1811,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, value = static_cast(lexer_.GetKind() == TokKind::kw_true); lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { - int64 value_s64; + tensorflow::int64 value_s64; if (!ParseInt64(&value_s64)) { return Error(value_loc, StrCat("expects integer for primitive type: ", @@ -1876,23 +1984,24 @@ bool HloParser::ParseAttributeHelper( LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(result); return true; } case AttrTy::kInt32: { - int64 result; + tensorflow::int64 result; if (!ParseInt64(&result)) { return false; } - if (result != static_cast(result)) { + if (result != static_cast(result)) { return Error(attr_loc, "value out of range for int32"); } - static_cast*>(attr_out_ptr) - ->emplace(static_cast(result)); + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); return true; } case AttrTy::kFloat: { @@ -1926,7 +2035,7 @@ bool HloParser::ParseAttributeHelper( } case AttrTy::kWindow: { Window result; - if (!ParseWindow(&result)) { + if (!ParseWindow(&result, /*expect_outer_curlies=*/true)) { return false; } static_cast*>(attr_out_ptr)->emplace(result); @@ -1968,12 +2077,12 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kBracedInt64List: { - std::vector result; + std::vector result; if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, &result)) { return false; } - static_cast>*>(attr_out_ptr) + static_cast>*>(attr_out_ptr) ->emplace(result); return true; } @@ -2018,6 +2127,9 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kDomain: { + return ParseDomain(static_cast(attr_out_ptr)); + } } }(); if (!success) { @@ -2044,9 +2156,10 @@ bool HloParser::ParseComputationName(HloComputation** value) { // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' // The subattributes can appear in any order. 'size=' is required, others are // optional. -bool HloParser::ParseWindow(Window* window) { +bool HloParser::ParseWindow(Window* window, bool expect_outer_curlies) { LocTy loc = lexer_.GetLoc(); - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + if (expect_outer_curlies && + !ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -2056,7 +2169,9 @@ bool HloParser::ParseWindow(Window* window) { std::vector lhs_dilate; std::vector rhs_dilate; std::vector rhs_reversal; - while (lexer_.GetKind() != TokKind::kRbrace) { + const auto end_token = + expect_outer_curlies ? TokKind::kRbrace : TokKind::kEof; + while (lexer_.GetKind() != end_token) { LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { @@ -2120,7 +2235,8 @@ bool HloParser::ParseWindow(Window* window) { window->mutable_dimensions(i)->set_window_reversal( rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); + return !expect_outer_curlies || + ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. @@ -2144,7 +2260,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( << str; } - const int64 rank = lhs_rhs_out[0].length(); + const tensorflow::int64 rank = lhs_rhs_out[0].length(); if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { return TokenError( "convolution lhs, rhs, and output must have the same rank"); @@ -2258,7 +2374,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { return false; } - std::vector> ranges; + std::vector> ranges; if (lexer_.GetKind() == TokKind::kRbrace) { // empty return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); @@ -2292,7 +2408,7 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { // ::= int64_val (delim int64_val)* bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, - std::vector* result) { + std::vector* result) { if (!ParseToken(start, StrCat("expects an int64 list starting with ", TokKindToString(start)))) { return false; @@ -2301,7 +2417,7 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, // empty } else { do { - int64 i; + tensorflow::int64 i; if (!ParseInt64(&i)) { return false; } @@ -2418,7 +2534,8 @@ bool HloParser::ParseString(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector* result) { +bool HloParser::ParseDxD(const string& name, + std::vector* result) { LocTy loc = lexer_.GetLoc(); if (!result->empty()) { return Error(loc, @@ -2426,7 +2543,7 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } // 1D if (lexer_.GetKind() == TokKind::kInt) { - int64 number; + tensorflow::int64 number; if (!ParseInt64(&number)) { return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } @@ -2446,7 +2563,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { return TokenError("expects token type kInt or kDxD"); } -bool HloParser::ParseWindowPad(std::vector>* pad) { +bool HloParser::ParseWindowPad( + std::vector>* pad) { LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { return Error(loc, "sub-attribute 'pad=' already exists"); @@ -2457,7 +2575,7 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (int i = 0; i < padding_str.size(); i++) { - std::vector low_high; + std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { return Error(loc, @@ -2481,7 +2599,7 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { - std::vector padding_dim; + std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { return Error(loc, @@ -2503,7 +2621,7 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { optional op_type; optional op_name; optional source_file; - optional source_line; + optional source_line; attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; @@ -2590,7 +2708,7 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } -bool HloParser::ParseInt64(int64* result) { +bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { return TokenError("expects integer"); @@ -2673,10 +2791,48 @@ bool HloParser::AddComputation(const string& name, HloComputation* computation, return true; } +StatusOr HloParser::ParseShardingOnly() { + lexer_.Lex(); + OpSharding op_sharding; + if (!ParseSharding(&op_sharding)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after sharding"); + } + return HloSharding::FromProto(op_sharding); +} + +StatusOr HloParser::ParseWindowOnly() { + lexer_.Lex(); + Window window; + if (!ParseWindow(&window, /*expect_outer_curlies=*/false)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument("Syntax error:\nExtra content after window"); + } + return window; +} + +StatusOr +HloParser::ParseConvolutionDimensionNumbersOnly() { + lexer_.Lex(); + ConvolutionDimensionNumbers dnums; + if (!ParseConvolutionDimensionNumbers(&dnums)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + if (lexer_.GetKind() != TokKind::kEof) { + return InvalidArgument( + "Syntax error:\nExtra content after convolution dnums"); + } + return dnums; +} + } // namespace -StatusOr> Parse(StringPiece str, - const HloModuleConfig& config) { +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); @@ -2684,10 +2840,29 @@ StatusOr> Parse(StringPiece str, return parser.ConsumeHloModule(); } -StatusOr> Parse(StringPiece str) { +StatusOr> ParseHloString( + tensorflow::StringPiece str) { + HloModuleConfig config; + return ParseHloString(str, config); +} + +StatusOr ParseSharding(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseShardingOnly(); +} + +StatusOr ParseWindow(tensorflow::StringPiece str) { + HloModuleConfig config; + HloParser parser(str, config); + return parser.ParseWindowOnly(); +} + +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str) { HloModuleConfig config; - return Parse(str, config); + HloParser parser(str, config); + return parser.ParseConvolutionDimensionNumbersOnly(); } -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h similarity index 52% rename from tensorflow/compiler/xla/tools/parser/hlo_parser.h rename to tensorflow/compiler/xla/service/hlo_parser.h index 2f97a2b9b19d0cdb64a2869913da62c55e14c1d5..3f3a51215e34bbdd667f1cb20d0ae968e0ce5efd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -13,30 +13,47 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_lexer.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { -namespace tools { + +// For details about the syntax accepted by this parser, see +// g3doc/hlo_parser.md. // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with the given config. -StatusOr> Parse(tensorflow::StringPiece str, - const HloModuleConfig& config); +StatusOr> ParseHloString( + tensorflow::StringPiece str, const HloModuleConfig& config); // The api of the hlo parser. Given a string in the HloModule::ToString() // format, parses the string and creates a HloModule with default config. -StatusOr> Parse(tensorflow::StringPiece str); +StatusOr> ParseHloString( + tensorflow::StringPiece str); + +// Parses the result of HloSharding::ToString(), e.g. "{replicated}". +StatusOr ParseSharding(tensorflow::StringPiece str); + +// Parses the result of window_util::ToString(const Window&). +StatusOr ParseWindow(tensorflow::StringPiece str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +StatusOr ParseConvolutionDimensionNumbers( + tensorflow::StringPiece str); + +// ParseHloString sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string. +StatusOr ParseSharding(tensorflow::StringPiece str); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc similarity index 88% rename from tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc rename to tensorflow/compiler/xla/service/hlo_parser_test.cc index 131aded95ab04c4327c275ed8cd18b8fc7ac1bd6..f834d34d57106b11cf398f966d8c0224f00d1b8d 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -13,19 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace xla { -namespace tools { + namespace { -using tensorflow::StringPiece; +using ::tensorflow::StringPiece; struct TestData { string test_name; @@ -233,6 +234,17 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} } +)" +}, +{ +"DomainParsing", +R"(HloModule DomainParsing_module + +ENTRY %DomainParsing (v1: f32[]) -> f32[] { + %v1 = f32[] parameter(0) + ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} +} + )" }, // int32 result = 0; @@ -753,7 +765,7 @@ add_F32.v3 { ENTRY MapBinaryAdder.v3 { param0 = f32[4]{0} parameter(0) param1 = f32[4]{0} parameter(1) - ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 + ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3 } )" @@ -888,6 +900,42 @@ ENTRY Gather { )" }, +// cross-replica-sum +{ +"CrossReplicaSum", +R"(HloModule CRS + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CRS { + input = f32[8]{0} parameter(0) + ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add +} + +)" +}, +// cross-replica-sum with subgroups +{ +"CrossReplicaSumWithSubgroups", +R"(HloModule CRS_Subgroups + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY CrossReplicaSumWithSubgroups { + input = f32[128,32]{0,1} parameter(0) + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), to_apply=add, replica_group_ids={0,0,1,1}, barrier="abc" +} + +)" +} }); // clang-format on } @@ -900,12 +948,12 @@ class HloParserTest : public ::testing::Test, << "'" << s << "' does not contain '" << expected << "'"; } - // Expects "ToString(Parse(string)) == string", that is, parses the string, - // asserts that it succeeded, stringifies the parsed module, and checks that - // the it equals the original string. + // Expects "ToString(ParseHloString(string)) == string", that is, parses the + // string, asserts that it succeeded, stringifies the parsed module, and + // checks that the it equals the original string. void ExpectEqual() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString( HloPrintOptions().set_print_large_constants(true))); @@ -916,7 +964,7 @@ class HloParserShortTest : public HloParserTest { protected: void ExpectEqualShort() { const string& original = GetParam().module_string; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ(original, result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); @@ -937,13 +985,13 @@ INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, TEST_F(HloParserTest, Empty) { const string original = ""; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } TEST_F(HloParserTest, Garbage) { const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -957,7 +1005,7 @@ ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -969,7 +1017,7 @@ ENTRY %blabla (x: g32[]) -> g32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -982,7 +1030,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -993,7 +1041,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); } @@ -1008,7 +1056,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // Constant instructions have no name. The string will be parsed successfully // but the constant names will not be exactly the same. @@ -1019,12 +1067,12 @@ TEST_F(HloParserTest, ConfigurationField) { ENTRY %configuration_test() -> s32[] { %constant = s32[] constant(42), backend_config="foo bar" })"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_ASSERT_OK(result.status()); EXPECT_EQ("foo bar", result.ValueOrDie() ->entry_computation() ->root_instruction() - ->backend_config()); + ->raw_backend_config_string()); } TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { @@ -1035,7 +1083,7 @@ ENTRY %some_2 () -> f32[2] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 1, but sees larger"); @@ -1049,7 +1097,7 @@ ENTRY %some_2x3 () -> f32[2,3] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects nested array in rank 2, but sees 1"); @@ -1063,7 +1111,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "expects 3 elements in the [0]th element"); @@ -1078,7 +1126,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); EXPECT_NE(Status::OK(), result.status()); ExpectHasSubstr(result.status().error_message(), "is out of range for literal's primitive type F16"); @@ -1092,7 +1140,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } )"; - auto result = Parse(original); + auto result = ParseHloString(original); TF_EXPECT_OK(result.status()); // The string will be parsed successfully but the output strings are not // exactly the same, because "3e2" is parsed into value 300 and will be @@ -1110,7 +1158,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, InvalidDimLabels) { @@ -1126,17 +1174,18 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; + ExpectHasSubstr(ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + ExpectHasSubstr( - Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + ParseHloString(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), - "expects dim labels pattern"); - - ExpectHasSubstr(Parse(tensorflow::strings::StrCat( - prefix, ",dim_labels=010_1100->010", suffix)) - .status() - .error_message(), - "must have the same rank"); + "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { @@ -1151,7 +1200,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "unexpected attribute \"calls\""); } @@ -1167,7 +1216,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "attribute channel_id is expected but not seen"); } @@ -1183,7 +1232,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "'done' is not defined"); } @@ -1196,7 +1245,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { @@ -1210,7 +1259,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } )"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects padding_low and padding_high separated by '_'"); } @@ -1222,7 +1271,7 @@ ENTRY %test_comma.v4 () -> f32[] { } )"; - TF_EXPECT_OK(Parse(original).status()); + TF_EXPECT_OK(ParseHloString(original).status()); } TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { @@ -1232,7 +1281,7 @@ ENTRY %CustomCall () -> f32[1] { %constant = f32[1]{0} constant({12345}) ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "Shape of computation CustomCall, f32[1], is not compatible " "with that of its root instruction foo, f32[1,2,3]"); } @@ -1251,7 +1300,7 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); auto program_layout = module.ValueOrDie()->host_entry_computation_layout(); ASSERT_EQ(program_layout.parameter_count(), 1); @@ -1274,7 +1323,7 @@ c1 { c2 { const2 = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); } @@ -1285,7 +1334,7 @@ ENTRY consts { first = f32[1]{0} constant({12345}) last = f32[1]{0} constant({67890}) })"; - auto module = Parse(original); + auto module = ParseHloString(original); TF_ASSERT_OK(module.status()); EXPECT_EQ( module.ValueOrDie()->entry_computation()->root_instruction()->name(), @@ -1300,7 +1349,7 @@ ENTRY c1 { ENTRY c2 { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "expects only one ENTRY"); } @@ -1310,25 +1359,10 @@ ENTRY consts { ROOT const1 = f32[1]{0} constant({12345}) ROOT const2 = f32[1]{0} constant({12345}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), "one computation should have only one ROOT"); } -TEST_F(HloParserTest, InstructionExists) { - const string original = R"(HloModule comp_exists -c1 { - instr = f32[1]{0} constant({12345}) -} -c2 { - instr = f32[1]{0} constant({67890}) -})"; - - ExpectHasSubstr(Parse(original).status().error_message(), - R"(was parsing 3:3: error: instruction previously defined here - instr = f32[1]{0} constant({12345}) - ^)"); -} - TEST_F(HloParserTest, ComputationExists) { const string original = R"(HloModule comp_exists comp { @@ -1337,12 +1371,52 @@ comp { comp { const2 = f32[1]{0} constant({67890}) })"; - ExpectHasSubstr(Parse(original).status().error_message(), + ExpectHasSubstr(ParseHloString(original).status().error_message(), R"(was parsing 2:1: error: computation previously defined here comp { ^)"); } +TEST_F(HloParserTest, CrossComputationLookup) { + const string original = R"(HloModule cross_computation_lookup: +tcalla (a: (s32[], s32[])) -> (s32[], s32[]) { + ROOT aparam = (s32[], s32[]) parameter(0) +} + +tcallb (b: (s32[], s32[])) -> s32[] { + rparam = (s32[], s32[]) parameter(0) + ROOT gte0 = s32[] get-tuple-element(aparam), index=0 +} + +ENTRY entry { + param = (s32[], s32[]) parameter(0) + call0 = (s32[], s32[]) call(param), to_apply=tcalla + ROOT call1 = s32[] call(param), to_apply=tcallb +})"; + ExpectHasSubstr( + ParseHloString(original).status().error_message(), + "was parsing 8:39: error: instruction does not exist: aparam"); +} + +TEST_F(HloParserTest, ParseSharding) { + const string original = "{maximal device=42}"; + TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); + EXPECT_EQ(sharding.ToString(), original); +} + +TEST_F(HloParserTest, ParseWindow) { + Window original = window_util::MakeWindow({1, 2, 3}); + TF_ASSERT_OK_AND_ASSIGN(Window parsed, + ParseWindow(window_util::ToString(original))) + EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed)); +} + +TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) { + const string original = "b0f_0io->b0f"; + TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums, + ParseConvolutionDimensionNumbers(original)); + EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums)); +} + } // namespace -} // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 8e167633bb13476301fa0c4afa0b123c9b47e40d..4738e46f8aeb96a4c25d04b3246bd21f644fe3ea 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -33,17 +33,27 @@ bool HloReachabilityMap::SetReachabilityToUnion( const HloInstruction* instruction) { BitVector& bit_vector = GetBitVector(instruction); tmp_bit_vector_ = bit_vector; + SetReachabilityToUnionHelper(inputs, instruction, &bit_vector); + return bit_vector != tmp_bit_vector_; +} +void HloReachabilityMap::FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction) { + SetReachabilityToUnionHelper(inputs, instruction, &GetBitVector(instruction)); +} + +void HloReachabilityMap::SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector) { // If instruction is part of inputs, don't reset the bit_vector. if (std::find(inputs.begin(), inputs.end(), instruction) == inputs.end()) { - bit_vector.SetToZero(); + bit_vector->SetToZero(); } - bit_vector.Set(GetIndex(instruction)); + bit_vector->Set(GetIndex(instruction)); for (const HloInstruction* input : inputs) { - bit_vector.OrWith(GetBitVector(input)); + bit_vector->OrWith(GetBitVector(input)); } - - return bit_vector != tmp_bit_vector_; } void HloReachabilityMap::SetReachable(const HloInstruction* a, diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 553ec11f6f9a2997ab7113f9b8241e04c7fe20d5..69bb2b3cee6dafe058c45b4e74e93401bea2cfc9 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -57,6 +57,11 @@ class HloReachabilityMap { tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); + // As above, but faster because it does not check if the reachability changed. + void FastSetReachabilityToUnion( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction); + // Sets entry so that IsReachable(a, b) will return true // // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency @@ -133,6 +138,11 @@ class HloReachabilityMap { return bit_vectors_[GetIndex(instruction)]; } + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. + void SetReachabilityToUnionHelper( + tensorflow::gtl::ArraySlice inputs, + const HloInstruction* instruction, BitVector* bit_vector); + // Return the index of the given instruction. The value is used to index into // the vector of BitVectors and the BitVectors themselves. int GetIndex(const HloInstruction* instruction) const { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 39b85de0f12024f5e20ddd37618987c6d06bc307..9c7bc7a5ea7c77dadb8772f08b823c3579cf2154 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -71,6 +71,20 @@ bool IsRematerializable(const HloInstruction* instruction) { } } +// Checks whether an instruction can be rematerialized, by looking up the +// cache before, and eventually calling the IsRematerializable() API. +bool CanBeRematerialized( + const HloInstruction* instruction, + tensorflow::gtl::FlatMap* remat_able) { + auto it = remat_able->find(instruction); + if (it != remat_able->end()) { + return it->second; + } + bool rematerializable = IsRematerializable(instruction); + (*remat_able)[instruction] = rematerializable; + return rematerializable; +} + // Type holding a unique identifier for each Buffer object. using BufferId = int64; using BufferIdList = tensorflow::gtl::InlinedVector; @@ -843,9 +857,10 @@ int64 RematerializationCost(const HloInstruction* instruction, // candidate which reduce memory use at the program point of the current // instruction as indicated by memory_tracker. nullptr is returned if no // candidate can be found. -Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, - const InstructionList& instruction_list, - int64 memory_limit_bytes) { +Item* PickRematerializationCandidate( + const MemoryUsageTracker& memory_tracker, + const InstructionList& instruction_list, int64 memory_limit_bytes, + tensorflow::gtl::FlatMap* remat_able) { Item* best_item = nullptr; int64 best_cost = 0; @@ -869,8 +884,7 @@ Item* PickRematerializationCandidate(const MemoryUsageTracker& memory_tracker, << " is excluded from rematerialization"; continue; } - - if (!IsRematerializable(candidate)) { + if (!CanBeRematerialized(candidate, remat_able)) { VLOG(5) << "candidate " << candidate->name() << " not viable: is not rematerializable"; continue; @@ -974,6 +988,9 @@ StatusOr HloRematerialization::RematerializeComputation( // blacklist. tensorflow::gtl::FlatSet remat_move_instructions; + // The map from instructions to their rematerializable status. + tensorflow::gtl::FlatMap remat_able; + // The peak memory of the computation at any point in the instruction // sequence. int64 peak_memory = memory_tracker.memory_usage(); @@ -1011,7 +1028,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); Item* best_item = PickRematerializationCandidate( - memory_tracker, instruction_list, memory_limit_bytes); + memory_tracker, instruction_list, memory_limit_bytes, &remat_able); if (best_item == nullptr) { VLOG(3) << "Unable to find rematerialization candidate at program " @@ -1213,7 +1230,7 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. - TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + TF_ASSIGN_OR_RETURN(*sequence, ScheduleComputationsInModule( *module, [this](const BufferValue& buffer) { return size_function_(buffer.shape()); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 83de54f3fa56ee660b79d8c366dbc0b52f9fde87..e81334d5a84268a129cd4e90091e97dc23243226 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -40,7 +41,8 @@ class HloRematerializationTest : public HloTestBase { // Creates and returns a computation which can benefit from // rematerialization. The computation looks like: // - // F32[] %param = {...} + // F32[1] %param = {...} + // F32[] %reshape = reshape(F32[], param) // F32[1024] %bcast = broadcast(%param) // F32[1024] %negate = negate(%bcast) // F32[2048] %concat_1 = concat({%negate, %negate}) @@ -57,9 +59,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( @@ -100,9 +104,11 @@ class HloRematerializationTest : public HloTestBase { const string& suffix = "") { auto builder = HloComputation::Builder(TestName() + suffix); auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, param, {})); + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); auto slice_1 = builder.AddInstruction( HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, /*limit_indices=*/{1}, @@ -135,6 +141,15 @@ class HloRematerializationTest : public HloTestBase { return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } + StatusOr RunHloRematerialization( + int64 memory_limit_bytes, HloModule* module, + SequentialHloOrdering::HloModuleSequence* sequence) { + TF_EXPECT_OK(verifier().Run(module).status()); + return HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, + sequence); + } + // Various shapes used in the canned computations. const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); @@ -158,11 +173,9 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/14 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -188,18 +201,16 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { HloComputation* computation = module->AddEntryComputation(MakeRematerializableComputation()); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/20 * 1024, + module.get(), &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); - EXPECT_EQ(computation->instruction_count(), 7); + EXPECT_EQ(computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -225,23 +236,21 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); // The body computation uses 16KB and the entry computation uses 2KB at the // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/17 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 8); + EXPECT_EQ(body_computation->instruction_count(), 8); } // Test rematerialization of a computation which calls another computation via a @@ -264,20 +273,18 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/body_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(body_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(body_computation->instruction_count(), 8); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/15 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // Both computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(body_computation->instruction_count(), 8); + // Both computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(body_computation->instruction_count(), 9); } // Test rematerialization of a doubly nested computation. All computations @@ -303,24 +310,22 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { module->AddEntryComputation(MakeRematerializableWhileComputation( while_cond, /*while_body=*/middle_computation)); - EXPECT_EQ(entry_computation->instruction_count(), 6); - EXPECT_EQ(middle_computation->instruction_count(), 6); - EXPECT_EQ(inner_computation->instruction_count(), 7); + EXPECT_EQ(entry_computation->instruction_count(), 7); + EXPECT_EQ(middle_computation->instruction_count(), 7); + EXPECT_EQ(inner_computation->instruction_count(), 8); // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/13 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); - // All computations should have a rematerialized instruction added. - EXPECT_EQ(entry_computation->instruction_count(), 7); - EXPECT_EQ(middle_computation->instruction_count(), 7); - EXPECT_EQ(inner_computation->instruction_count(), 8); + // All computations should have rematerialized instructions added. + EXPECT_EQ(entry_computation->instruction_count(), 9); + EXPECT_EQ(middle_computation->instruction_count(), 9); + EXPECT_EQ(inner_computation->instruction_count(), 9); } TEST_F(HloRematerializationTest, RngNotRematerialized) { @@ -382,10 +387,9 @@ TEST_F(HloRematerializationTest, RngNotRematerialized) { // parameter and output) and 20KB (peak memory possible with // rematerialization). TF_ASSERT_OK_AND_ASSIGN( - bool changed, HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, + bool changed, RunHloRematerialization( /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), - module.get(), DefaultMemoryScheduler, &sequence)); + module.get(), &sequence)); EXPECT_TRUE(changed); // The rng should not have been rematerialized. EXPECT_EQ(count_rngs(entry_computation), 1); @@ -476,11 +480,9 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. @@ -573,11 +575,9 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), - DefaultMemoryScheduler, &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloRematerialization( + /*memory_limit_bytes=*/22 * 1024, + module.get(), &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 2a601ec3d183023954b6f1b6bca7594384378169..e1f9d8efd4974055947438c8a2e15cb77d1b5c75 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -22,9 +22,9 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -36,7 +36,7 @@ HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, const DebugOptions& debug_options) { HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } namespace { @@ -80,7 +80,7 @@ HloRunner::ReadModuleFromHloTextFile(const std::string& filename, filename, &hlo_string)); HloModuleConfig config; config.set_debug_options(debug_options); - return tools::Parse(hlo_string, config); + return ParseHloString(hlo_string, config); } HloRunner::HloRunner(se::Platform* platform) { @@ -92,53 +92,108 @@ HloRunner::HloRunner(se::Platform* platform) { HloRunner::~HloRunner() {} -StatusOr> HloRunner::Execute( - std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - CreateExecutable(std::move(module), run_hlo_passes)); - se::Stream stream(backend().default_stream_executor()); - stream.Init(); +StatusOr HloRunner::TransferLiteralToDevice( + const Literal& literal) { + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), backend().memory_allocator(), + backend().default_device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, buffer)); + return std::move(buffer); +} - ServiceExecutableRunOptions service_run_options(GetServiceRunOptionsForDevice( - backend().default_device_ordinal(), &stream, nullptr)); - const ExecutableRunOptions& run_options = service_run_options.run_options(); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals) { + std::vector buffers; + for (const Literal* literal : literals) { + CHECK(literal != nullptr); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer, + TransferLiteralToDevice(*literal)); + buffers.push_back(std::move(buffer)); + } + return std::move(buffers); +} - // Copy arguments to device. - std::vector argument_buffers; - for (Literal* argument : arguments) { - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer argument_buffer, - backend().transfer_manager()->AllocateScopedShapedBuffer( - argument->shape(), run_options.allocator(), - run_options.device_ordinal())); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - stream.parent(), *argument, argument_buffer)); - argument_buffers.push_back(std::move(argument_buffer)); +StatusOr> HloRunner::TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals) { + std::vector literal_pointers; + literal_pointers.reserve(literals.size()); + for (const auto& literal : literals) { + literal_pointers.push_back(literal.get()); } + return TransferLiteralsToDevice(literal_pointers); +} + +StatusOr> HloRunner::TransferLiteralFromDevice( + const ShapedBuffer& buffer) { + return backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), buffer); +} - std::vector argument_buffer_ptrs; - argument_buffer_ptrs.reserve(argument_buffers.size()); - for (const auto& buf : argument_buffers) { - argument_buffer_ptrs.push_back(&buf); +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + TF_ASSIGN_OR_RETURN(std::vector argument_buffers, + TransferLiteralsToDevice(arguments)); + TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, + ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_buffers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile)); + return TransferLiteralFromDevice(result); +} + +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice> arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(argument.get()); } + return Execute( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); +} - TF_ASSIGN_OR_RETURN( - ScopedShapedBuffer result, - executable->ExecuteOnStreamWrapper( - &service_run_options, /*profile=*/nullptr, argument_buffer_ptrs)); - - auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( - stream.parent(), result); - if (result_literal.ok()) { - VLOG(4) << "Executed binary and got result: " - << result_literal.ValueOrDie()->ToString(); - } else { - VLOG(4) << "Executed binary and got status: " - << result_literal.status().ToString(); +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + // Get service run options. + se::Stream stream(backend().default_stream_executor()); + stream.Init(); + ServiceExecutableRunOptions service_run_options = + GetServiceRunOptionsForDevice(backend().default_device_ordinal(), &stream, + nullptr); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(module), run_hlo_passes)); + return executable->ExecuteOnStreamWrapper(&service_run_options, + /*profile=*/profile, arguments); +} + +StatusOr HloRunner::ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes, ExecutionProfile* profile) { + std::vector argument_pointers; + argument_pointers.reserve(arguments.size()); + for (const auto& argument : arguments) { + argument_pointers.push_back(&argument); } - return result_literal; + return ExecuteWithDeviceBuffers( + /*module=*/std::move(module), + /*arguments=*/argument_pointers, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); } StatusOr>> HloRunner::ExecuteReplicated( @@ -295,4 +350,8 @@ Backend& HloRunner::backend() { return *backend_; } +const Backend& HloRunner::backend() const { + return const_cast(this)->backend(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 53f7c6fe4a09111c5ee24f2290f0f4aeed0a4401..65537f07f56e74b7fe2c2f9792af21efc7229573 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -102,6 +102,15 @@ class HloRunner { static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); + // Transfers data between the host and device. + StatusOr TransferLiteralToDevice(const Literal& literal); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice literals); + StatusOr> TransferLiteralsToDevice( + const tensorflow::gtl::ArraySlice> literals); + StatusOr> TransferLiteralFromDevice( + const ShapedBuffer& buffer); + // Executes the given module with given literals as input and returns the // result as a Literal. // @@ -109,20 +118,25 @@ class HloRunner { // optimization. StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes = true); + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); StatusOr> Execute( std::unique_ptr module, const tensorflow::gtl::ArraySlice> arguments, - bool run_hlo_passes = true) { - // Construct a vector of plain pointers for the arguments. - std::vector argument_pointers; - c_transform( - arguments, std::back_inserter(argument_pointers), - [](const std::unique_ptr& literal) { return literal.get(); }); - return Execute(std::move(module), argument_pointers, run_hlo_passes); - } + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + // As Execute(), but accepts and returns device buffers instead of host + // buffers. + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); + + StatusOr ExecuteWithDeviceBuffers( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true, ExecutionProfile* profile = nullptr); // Executes a given HLO module into a set of replicas, and returns a map // with the replica number as key, and the corresponding returned literal as @@ -137,6 +151,7 @@ class HloRunner { // This creates the backend lazily so it's possible to instantiate an // HloRunner in a program without any backends linked in. Backend& backend(); + const Backend& backend() const; private: // Creates an executable object given an HLO module. If run_hlo_passes is diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 854aa943199397c0e3f84d48a74ef41ae0d3db56..b14ade3549d093acdb5cdc7ae99dd025a42d5621 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -36,29 +36,6 @@ using ::tensorflow::strings::HumanReadableNumBytes; namespace xla { -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - namespace { // Class implementing a list scheduler of HLO instructions which produces a @@ -299,6 +276,8 @@ class ListScheduler { auto best_it = ready_queue.end(); --best_it; const HloInstruction* best = best_it->second.instruction; + VLOG(2) << "Schedule instruction: " << best->ToShortString() + << " Bytes freed: " << best_it->first.first; ready_queue.erase(best_it); ready_instructions.erase(best); schedule.push_back(best); @@ -396,7 +375,7 @@ int64 SumLogicalBufferSizes( return size; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleComputationsInModule( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_function, @@ -414,18 +393,6 @@ StatusOr> CreateMemoryMinimizingSequence( } // namespace -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; -} - StatusOr> DFSMemoryScheduler( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, @@ -437,6 +404,7 @@ StatusOr> DFSMemoryScheduler( // simply users-1 for each instruction. By subtracting 1, we're saying that // instructions with no users or a single user don't count; instructions with // lots of fan-out will be visited earlier. + int64 cumulative_total_size = 0; tensorflow::gtl::FlatMap extra_users; tensorflow::gtl::FlatMap total_sizes; for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { @@ -449,12 +417,21 @@ StatusOr> DFSMemoryScheduler( int64 logical_buffer_size = SumLogicalBufferSizes( points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); total_sizes[hlo] = logical_buffer_size; + cumulative_total_size += logical_buffer_size; tensorflow::gtl::FlatSet unique_operands( hlo->operands().begin(), hlo->operands().end()); for (const HloInstruction* operand : unique_operands) { extra_users[hlo] += extra_users[operand]; total_sizes[hlo] += total_sizes[operand]; } + // total_sizes[hlo] transitively includes the sizes of all nodes that + // lead to it. But computation is a DAG, so we are double-counting nodes, + // which can lead to overflows for large programs. + // cumulative_total_size caps the size to prevent overflows. + // NOTE(dimvar): this is quite ugly and should be changed. It's unclear + // why we care about transitive sizes; when scheduling a node, its input + // and output buffers should be all that matters, not its "history". + total_sizes[hlo] = std::min(total_sizes[hlo], cumulative_total_size); } CHECK_EQ(extra_users.size(), computation.instruction_count()); CHECK_EQ(total_sizes.size(), computation.instruction_count()); @@ -564,10 +541,9 @@ StatusOr> DefaultMemoryScheduler( } } -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm) { +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); @@ -575,7 +551,7 @@ CreateMemoryMinimizingSequence(const HloModule& module, for (const auto* computation : module.MakeComputationPostOrder()) { if (!computation->IsFusionComputation()) { TF_ASSIGN_OR_RETURN(auto one_computation_sequence, - CreateMemoryMinimizingSequence( + ScheduleComputationsInModule( *computation, *points_to_analysis, size_function, algorithm, memory_by_computation)); memory_by_computation[computation] = @@ -588,15 +564,15 @@ CreateMemoryMinimizingSequence(const HloModule& module, return sequence; } -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); tensorflow::gtl::FlatMap empty_map; - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function, nullptr, empty_map); + return ScheduleComputationsInModule(computation, *points_to_analysis, + size_function, nullptr, empty_map); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index 49b927eefd24f4e26df781dd8d2b977bedba2b80..2b33ccc8bfb895286bb3747aab0a16cf25e2cfae 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -28,20 +28,6 @@ limitations under the License. namespace xla { -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns the minimum memory required to compute the given computation, -// assuming no fragmentation. -StatusOr MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function); - // A memory scheduler computes an execution sequence for the HLO instructions in // 'computation' that minimizes peak memory, given a points-to analysis result // that describes buffer aliasing, together with a target-specific size function @@ -89,14 +75,13 @@ StatusOr> DefaultMemoryScheduler( // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. -StatusOr -CreateMemoryMinimizingSequence(const HloModule& module, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerAlgorithm& algorithm = {}); +StatusOr ScheduleComputationsInModule( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerAlgorithm& algorithm = {}); -// Overload of above that computes the sequence for a single computation. +// Computes the schedule for a single computation. // Currently only used by the GPU backend. -StatusOr> CreateMemoryMinimizingSequence( +StatusOr> ScheduleOneComputation( const HloComputation& computation, const LogicalBuffer::SizeFunction& size_function); diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index c018ba2ffc404d0c6a0d08b8f5c63a9f90888b70..6f1b1215d39dfbaeff768de70fa0a0859cd97381 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -22,74 +22,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" namespace xla { namespace { -class MinimumMemoryForSequenceTest : public HloTestBase {}; - -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); -} - class HloSchedulingTest : public HloTestBase {}; TEST_F(HloSchedulingTest, LastUseScheduledFirst) { @@ -124,7 +65,7 @@ TEST_F(HloSchedulingTest, LastUseScheduledFirst) { TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, [](const BufferValue& buffer) { + ScheduleComputationsInModule(*module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); })); // Verify that all instructions are in the sequence. @@ -158,14 +99,14 @@ ENTRY root { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(module_str)); + ParseHloString(module_str)); auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); }; TF_ASSERT_OK_AND_ASSIGN( SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence(*module, size_fn, ListMemoryScheduler)); + ScheduleComputationsInModule(*module, size_fn, ListMemoryScheduler)); // Verify that all instructions are in the sequence. EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.at(module->entry_computation()).size()); @@ -270,7 +211,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { module->AddEntryComputation(builder.Build()); TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, - CreateMemoryMinimizingSequence( + ScheduleComputationsInModule( *module, [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); @@ -289,5 +230,100 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { EXPECT_TRUE(ordering.ExecutesBefore(transpose, add)); } +TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) { + auto builder = HloComputation::Builder(TestName()); + const auto TUPLE_SIZE = 1; + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6}); + + // Wrap lit in abs because constants are considered free by + // IgnoreInstruction, and it skews the accounting. + auto lit = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1, 1}))); + auto abs_const = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit)); + + auto abs_abs1 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple( + tensorflow::gtl::ArraySlice({abs_abs1}))); + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto abs_abs2 = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const)); + + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, + tuple_elm, abs_abs2)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN( + SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule(*module, + [&TUPLE_SIZE](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf( + buffer.shape(), TUPLE_SIZE); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // tuple allocates the tuple buffer and doesn't free anything. + // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0. + // abs_abs2 should be scheduled before tuple by List. + EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple)); +} + +TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { + const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5}); + HloComputation::Builder builder(TestName()); + + auto c1 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 1, 1, 1, 1}))); + auto c2 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1, 2, 3, 4, 5}))); + auto c3 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0, 2, 4, 6, 8}))); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul})); + + auto tuple_elm = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(r1f32, tuple, 0)); + + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3)); + + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp)); + + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(builder.Build()); + + auto fusion = computation->CreateFusionInstruction( + {tuple, mul, add}, HloInstruction::FusionKind::kLoop); + + TF_ASSERT_OK_AND_ASSIGN(SequentialHloOrdering::HloModuleSequence sequence, + ScheduleComputationsInModule( + *module, + [](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), 2); + }, + ListMemoryScheduler)); + + // Verify that all instructions are in the sequence. + EXPECT_EQ(module->entry_computation()->instruction_count(), + sequence.at(module->entry_computation()).size()); + SequentialHloOrdering ordering(module.get(), sequence); + // fusion allocates memory for the tuple elements and doesn't free anything, + // so it's more expensive than exp. + EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 7f7e3f7dab03ce0ad64bd0fcfe4ddd020d31bf56..9fb15df7c26951fb7f0d62b0d6533d6312e7a4d5 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -39,6 +39,34 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { return HloSharding(tile_shape, assignment); } +HloSharding HloSharding::Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve(sub_shardings.leaf_count()); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + if (flattened_list.empty()) { + // Empty tuple sharding ends up having no leaves, but we want to allow + // empty tuple HLO instruction results to have sharding, so we fetch the + // root ({}) sharding value from the ShapeTree. + // A ShapeTree created with ShapeTree(shape, init) will have + // init as value at its root. + flattened_list.push_back(sub_shardings.element(ShapeIndex({}))); + } + return HloSharding(flattened_list); +} + +HloSharding HloSharding::Tuple( + const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)) << ShapeUtil::HumanString(tuple_shape); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), RequiredLeaves(tuple_shape)) + << "Flat list has " << flattened_list.size() << ", required " + << RequiredLeaves(tuple_shape); + return HloSharding(flattened_list); +} + string HloSharding::ToString() const { if (IsTuple()) { std::vector parts; @@ -49,9 +77,6 @@ string HloSharding::ToString() const { return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); } - string result = StrCat("{", (replicated_ ? " replicated" : ""), - (maximal_ ? " maximal" : "")); - if (replicated_) { return "{replicated}"; } else if (maximal_) { @@ -126,6 +151,49 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { return index; } +int64 HloSharding::RequiredLeaves(const Shape& shape) { + // Empty tuples have no leaf nodes as far as ShapeUtil and ShapeTree are + // concerned, but they do have a single tuple_elements_ entry since we want + // to allow empty tuple results to have sharding. + return ShapeUtil::IsEmptyTuple(shape) ? 1 : ShapeUtil::GetLeafCount(shape); +} + +Status HloSharding::CheckLeafCount(const Shape& shape) const { + int64 shape_leaves = RequiredLeaves(shape); + TF_RET_CHECK(shape_leaves == tuple_elements_.size()) + << "Shape " << ShapeUtil::HumanString(shape) << " has " << shape_leaves + << " leaf nodes while this sharding has " << tuple_elements_.size(); + return Status::OK(); +} + +StatusOr> HloSharding::AsShapeTree( + const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + if (ShapeUtil::IsEmptyTuple(shape)) { + // Empty tuples have no leaves, but we want to assign them a sharding + // anyway, so we use the root element sharding. + *result.mutable_element(ShapeIndex({})) = *it; + } + return std::move(result); + } else { + return ShapeTree(shape, *this); + } +} + +StatusOr HloSharding::GetTupleSharding(const Shape& shape) const { + if (IsTuple()) { + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); + return *this; + } + return Tuple(ShapeTree(shape, *this)); +} + StatusOr HloSharding::UniqueDevice() const { if (IsTuple()) { if (tuple_elements_.empty()) { @@ -167,28 +235,12 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { return tensorflow::errors::InvalidArgument( StrCat("Sharding is tuple-shaped but validation shape is not.")); } - // The easiest way to get the number of elements in a nested tuple is just to - // create a shape tree. We could call GetAsShapeTree, but that will try and - // apply our tuple_shardings_ to the shape tree, and that might cause a crash - // at this point as we haven't validated them. - ShapeTree bool_shape_tree(shape, false); - int64 num_leaves = - std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); - if (num_leaves != tuple_elements_.size()) { - return tensorflow::errors::InvalidArgument( - StrCat("Validation tuple shape has ", num_leaves, - " leaf elements, but this sharding contains ", - tuple_elements_.size(), " elements.")); - } + TF_RETURN_IF_ERROR(CheckLeafCount(shape)); // Now we've validated the number of tuple elements, it's safe to request a // shape tree. ShapeTree shape_tree = GetAsShapeTree(shape); for (const auto& index_to_sharding : shape_tree.leaves()) { - if (index_to_sharding.first.empty()) { - // An empty tuple has a ShapeTree with a single leaf at the empty index. - continue; - } Status status = index_to_sharding.second.ValidateNonTuple( ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); if (!status.ok()) { @@ -370,11 +422,21 @@ HloSharding HloSharding::GetSubSharding(const Shape& shape, Shape sub_shape = ShapeUtil::GetSubshape(shape, index); ShapeTree sub_shape_tree(sub_shape, Replicate()); sub_shape_tree.CopySubtreeFrom(GetAsShapeTree(shape), index, {}); - if (ShapeUtil::IsTuple(sub_shape)) { - return Tuple(sub_shape_tree); - } else { - return sub_shape_tree.element({}); + return ShapeUtil::IsTuple(sub_shape) ? Tuple(sub_shape_tree) + : sub_shape_tree.element(ShapeIndex({})); +} + +tensorflow::gtl::optional HloSharding::ExtractSingleSharding() + const { + if (!IsTuple()) { + return *this; + } + for (int64 i = 1; i < tuple_elements_.size(); ++i) { + if (tuple_elements_[0] != tuple_elements_[i]) { + return tensorflow::gtl::optional(); + } } + return tuple_elements_.front(); } std::ostream& operator<<(std::ostream& out, const HloSharding& sharding) { diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 2b8e757f42991f697df37d3d34bfdff6a36bc509..6a744e0247273e25c5de3143b7bbba2b79ee816a 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -70,26 +70,13 @@ class HloSharding { // Creates a new sharding for a tuple type. The given ShapeTree must have // elements for every leaf shape contained in the tuple. - static HloSharding Tuple(const ShapeTree& sub_shardings) { - std::vector flattened_list; - flattened_list.reserve( - std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); - for (const auto& index_to_sharding : sub_shardings.leaves()) { - flattened_list.push_back(index_to_sharding.second); - } - return HloSharding(flattened_list); - } + static HloSharding Tuple(const ShapeTree& sub_shardings); - // Creates a new sharding for a tuple type. The requested tuple shape must not - // be nested. For nested tuples, use the ShapeTree overload. + // Creates a new sharding for a tuple type. The number of elements in + // shardings must match the number of leaf nodes in tuple_shape. For + // empty tuples, the shardings array must have one element. static HloSharding Tuple(const Shape& tuple_shape, - tensorflow::gtl::ArraySlice shardings) { - CHECK(ShapeUtil::IsTuple(tuple_shape)); - CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); - std::vector flattened_list(shardings.begin(), shardings.end()); - CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); - return HloSharding(flattened_list); - } + tensorflow::gtl::ArraySlice shardings); // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -99,6 +86,9 @@ class HloSharding { static bool IsReservedDevice(int64 device) { return device < 0; } OpSharding ToProto() const; + + // Note that this string canonically has outer curly braces, e.g. + // "{replicated}". string ToString() const; // Validate that this sharding can be applied to a tensor with shape `shape`. @@ -160,25 +150,27 @@ class HloSharding { // tuple, if IsTuple, or a ShapeTree with a single element containing this // sharding. Only the leaf elements are populated. This creates a new // ShapeTree object so is not cheap. + StatusOr> AsShapeTree(const Shape& shape) const; ShapeTree GetAsShapeTree(const Shape& shape) const { - if (IsTuple()) { - ShapeTree result(shape, HloSharding::Replicate()); - CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), - tuple_elements_.size()); - auto it = tuple_elements_.begin(); - for (auto& index_to_sharding : result.leaves()) { - index_to_sharding.second = *it++; - } - return result; - } else { - return ShapeTree(shape, *this); - } + return AsShapeTree(shape).ValueOrDie(); } // Retrieves the sub sharding at a given index, out of a tuple sharding. // REQUIRES: IsTuple() HloSharding GetSubSharding(const Shape& shape, const ShapeIndex& index) const; + // If the current sharding is a tuple sharding, return itself as result. + // Otherwise returns a tuple sharding for the input shape, with all the leaves + // having this object sharding. + StatusOr GetTupleSharding(const Shape& shape) const; + + // Extracts the sharding that is common within the current sharding. + // If the current sharding is not a tuple sharding, the current sharding will + // be returned. If it is a tuple, and all the tuple elements are common, the + // common element will be returned. Otherwise the optional will contain no + // value. + tensorflow::gtl::optional ExtractSingleSharding() const; + bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && ShapeUtil::Compatible(tile_shape_, other.tile_shape_) && @@ -208,6 +200,12 @@ class HloSharding { return h; } + struct Hasher { + size_t operator()(const HloSharding& sharding) const { + return sharding.Hash(); + } + }; + // Gets the tile shape. // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } @@ -261,11 +259,19 @@ class HloSharding { tile_assignment_({0}), tuple_elements_(tuple_shardings) {} + // Checks that the number of elements in tuple_elements_ is consistent with + // the tuple shape passes as argument. + Status CheckLeafCount(const Shape& shape) const; + // Internal helper to validate a tuple sharding. Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; + // Returns the number of tuple_elements_ entries to fit the shape. + static int64 RequiredLeaves(const Shape& shape); + bool replicated_; bool maximal_; bool tuple_; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b4b071af46df19520f9ba1f1f632692d489de59 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -0,0 +1,394 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +namespace { + +struct PassThrough { + PassThrough(HloInstruction* user, HloInstruction* operand) + : user(user), operand(operand) {} + + HloInstruction* user = nullptr; + HloInstruction* operand = nullptr; +}; + +void SetSingleSharding(HloInstruction* instruction, + const HloSharding& sharding) { + VLOG(4) << " " << instruction->name() << " to " << sharding; + instruction->set_single_sharding(sharding); +} + +bool ShardingMatches(const HloSharding& sharding1, + const HloSharding& sharding2) { + auto single_sharding1 = sharding1.ExtractSingleSharding(); + if (single_sharding1) { + auto single_sharding2 = sharding2.ExtractSingleSharding(); + if (single_sharding2) { + return *single_sharding1 == single_sharding2; + } + } + // Anything which is not unique across all elements, gets a full sharding + // compare. + return sharding1 == sharding2; +} + +// When we create domains, they are never "empty", where with empty we mean +// that a kDomain instruction has as operand another kDomain instruction of the +// same kind. +// But when the HLO optimizations are run, empty domains can be created. +// For example: +// +// Domain(device=None, device=0) -> +// Tuple(device=0) -> +// GTE(device=0) -> +// Domain(device=0, device=None) +// +// In that case the tuple simplifier could create something like: +// +// Domain(device=None, device=0) -> Domain(device=0, device=None) +// +// Which is a so called empty domain. +// In the case above, crossing an empty domain which was transiting through +// device 0, requires the normalization phase to fixup the empty domain by +// adding back a Tuple+GTE pair with the proper device. +// One particular case where this can create problems is the result of the +// entry computation, where the GTE assignments are used by TF to tell the +// XLA where the results should be sent. +std::vector LocatePassThroughDomainLinks( + const DomainMetadata::Domain& domain) { + std::vector pass_through; + for (HloInstruction* instruction : domain.enter_domains) { + CHECK(instruction->opcode() == HloOpcode::kDomain) + << "Instruction is not a kDomain: " << instruction->ToString(); + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kDomain && + domain.exit_domains.count(user) != 0) { + pass_through.emplace_back(user, instruction); + VLOG(2) << "Found passthrough domain link:"; + VLOG(2) << " " << user->ToString(); + VLOG(2) << " " << instruction->ToString(); + } + } + } + return pass_through; +} + +Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + for (auto& pass_through : LocatePassThroughDomainLinks(domain)) { + HloInstruction* tuple = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateTuple({pass_through.operand})); + HloInstruction* gte = pass_through.operand->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(pass_through.operand->shape(), + tuple, 0)); + gte->set_sharding(sharding); + TF_RETURN_IF_ERROR( + pass_through.operand->ReplaceUseWith(pass_through.user, gte)); + } + return Status::OK(); +} + +std::unique_ptr CloneShardingForDomain( + const HloSharding& sharding) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (!single_sharding) { + return MakeUnique(sharding); + } + return MakeUnique(*single_sharding); +} + +Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + VLOG(4) << "Applying " << sharding << " sharding"; + for (HloInstruction* instruction : domain.instructions) { + // We only change instructions without sharding, since otherwise we might + // mess up with eventual HLO passes which has knowledge of it. + if (!instruction->has_sharding()) { + SetSingleSharding(instruction, sharding); + } else { + VLOG(4) << " " << instruction->name() << " already has sharding " + << instruction->sharding(); + } + } + return Status::OK(); +} + +// Retrieves the sharding of a tuple shaped instruction in form of a ShapeTree. +// If the instruction has no sharding, a ShapeTree with HloSharding::Replicate() +// sharding will be returned. +ShapeTree GetTupleSharding(HloInstruction* tuple) { + if (tuple->has_sharding()) { + return tuple->sharding().GetAsShapeTree(tuple->shape()); + } + return ShapeTree(tuple->shape(), HloSharding::Replicate()); +} + +// Retrieves the sharding of operand, asked from a user instruction which is +// within domain. If operand is a kDomain, it means that sharding argument is +// the operand sharding, otherwise the operand's own sharding will be returned. +const HloSharding* GetOperandSharding(const HloInstruction* operand, + const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + DCHECK_EQ(domain.reach_set.count(const_cast(operand)), 1); + // Here the user of operand is within the domain instruction set, and since it + // is user of operand, we need to look into the enter_domains set. If this is + // not a kDomain within the user domains set, then return the operand + // sharding, if any. + if (operand->opcode() != HloOpcode::kDomain || + domain.enter_domains.count(const_cast(operand)) == 0) { + return operand->has_sharding() ? &operand->sharding() : nullptr; + } + // At this point operand is a kDomain of the currently processed domain, so we + // can refer to sharding as the domain sharding. + return &sharding; +} + +// Tries to propagate the sharding information into the instructions that are +// part of the domain, in a post order manner (operand propagate to user). +StatusOr ApplyDomainShardingPass(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + int64 assigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (instruction->has_sharding()) { + continue; + } + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + HloInstruction* tuple = instruction->mutable_operand(0); + const HloSharding* tuple_sharding = + GetOperandSharding(tuple, domain, sharding); + if (tuple_sharding != nullptr) { + if (tuple_sharding->IsTuple()) { + HloSharding sub_sharding = tuple_sharding->GetSubSharding( + tuple->shape(), {instruction->tuple_index()}); + VLOG(4) << " " << instruction->name() << " to sharding " + << sub_sharding; + instruction->set_sharding(sub_sharding); + } else { + SetSingleSharding(instruction, *tuple_sharding); + } + ++assigned; + } + } else if (instruction->opcode() == HloOpcode::kTuple) { + int64 tuple_assigned = 0; + ShapeTree shape_tree = GetTupleSharding(instruction); + for (int64 i = 0; i < instruction->operand_count(); ++i) { + const HloSharding* operand_sharding = + GetOperandSharding(instruction->operand(i), domain, sharding); + if (operand_sharding != nullptr && + shape_tree.element({i}) != *operand_sharding) { + *shape_tree.mutable_element({i}) = *operand_sharding; + ++tuple_assigned; + } + } + if (tuple_assigned > 0) { + HloSharding tuple_sharding = HloSharding::Tuple(shape_tree); + VLOG(4) << " " << instruction->name() << " to sharding " + << tuple_sharding; + instruction->set_sharding(tuple_sharding); + ++assigned; + } + } else { + // If all the operand of the given instruction has the same single device + // assignment, assign that device to this instruction as well. + const HloSharding* common_sharding = nullptr; + for (const HloInstruction* operand : instruction->operands()) { + const HloSharding* operand_sharding = + GetOperandSharding(operand, domain, sharding); + if (operand_sharding != nullptr) { + if (common_sharding != nullptr && + *common_sharding != *operand_sharding) { + common_sharding = nullptr; + break; + } + common_sharding = operand_sharding; + } + } + if (common_sharding != nullptr) { + VLOG(4) << " " << instruction->name() << " to sharding " + << *common_sharding; + instruction->set_sharding(*common_sharding); + ++assigned; + } + } + } + return assigned; +} + +Status ApplyDomainSharding(const DomainMetadata::Domain& domain, + const HloSharding& sharding) { + auto single_sharding = sharding.ExtractSingleSharding(); + if (single_sharding) { + // Shortcut the simple case. We have a unique sharding, so we call + // the ApplyDomainSingleSharding() API which will apply array or tuple + // shaped sharding to the domain instructions. + return ApplyDomainSingleSharding(domain, *single_sharding); + } + VLOG(1) << "Assigning non-trivial sharding " << sharding; + for (;;) { + TF_ASSIGN_OR_RETURN(int64 assigned, + ApplyDomainShardingPass(domain, sharding)); + if (assigned == 0) { + break; + } + } + int64 unassigned = 0; + for (HloInstruction* instruction : domain.instructions) { + if (!instruction->has_sharding()) { + LOG(WARNING) << "Unassigned instruction: " << instruction->ToString(); + ++unassigned; + } + } + // Should we error out if unassigned > 0? + return Status::OK(); +} + +// Creates a kDomain instruction to be placed between instruction and operand. +// The kDomain instruction will be created only if the sharding differ between +// the instruction and the operand. +std::unique_ptr CreateDomain(HloInstruction* instruction, + HloInstruction* operand) { + const HloSharding* instruction_sharding = + instruction->has_sharding() ? &instruction->sharding() : nullptr; + const HloSharding* operand_sharding = + operand->has_sharding() ? &operand->sharding() : nullptr; + // No need for domain if they both have no sharding. + if (instruction_sharding == nullptr && operand_sharding == nullptr) { + return nullptr; + } + // No need for domain if they match. + if (instruction_sharding != nullptr && operand_sharding != nullptr && + ShardingMatches(*instruction_sharding, *operand_sharding)) { + return nullptr; + } + std::unique_ptr real_instruction_sharding; + std::unique_ptr real_operand_sharding; + if (instruction_sharding != nullptr) { + real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); + } + if (operand_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*operand_sharding); + } + VLOG(3) << "Creating domain:"; + VLOG(3) << " Instruction: " << instruction->name(); + VLOG(3) << " Operand: " << operand->name(); + VLOG(3) << " User side sharding: " + << (real_instruction_sharding != nullptr + ? real_instruction_sharding->ToString() + : "None"); + VLOG(3) << " Operand side sharding: " + << (real_operand_sharding != nullptr + ? real_operand_sharding->ToString() + : "None"); + + std::unique_ptr operand_side_metadata = + MakeUnique(std::move(real_operand_sharding)); + std::unique_ptr user_side_metadata = + MakeUnique(std::move(real_instruction_sharding)); + return HloInstruction::CreateDomain(operand->shape(), operand, + std::move(operand_side_metadata), + std::move(user_side_metadata)); +} + +StatusOr> ExtractOriginalCommonSharding( + tensorflow::gtl::ArraySlice instructions) { + // If we are here, all the instructions being passed had the same sharding + // (or no sharding), by the means of the ShardingMatches() API. + // As such, no kDomain was inserted, and here we are asked to extract the + // original common sharding. + // All the instructions passed to this API are part of the same computation. + const HloSharding* sharding = nullptr; + for (HloInstruction* instruction : instructions) { + if (instruction->has_sharding()) { + if (sharding == nullptr) { + sharding = &instruction->sharding(); + } else { + TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding())) + << "Sharding " << *sharding << " does not match the one in " + << instruction->ToString(); + } + } + } + if (sharding == nullptr) { + return std::unique_ptr(); + } + VLOG(4) << "Extracted sharding is " << *sharding; + return CloneShardingForDomain(*sharding); +} + +} // namespace + +std::unique_ptr ShardingMetadata::Clone() const { + std::unique_ptr sharding; + if (sharding_ != nullptr) { + sharding = MakeUnique(*sharding_); + } + return MakeUnique(std::move(sharding)); +} + +bool ShardingMetadata::Matches(const DomainMetadata& other) const { + const ShardingMetadata* other_ptr = + dynamic_cast(&other); + if (other_ptr == nullptr) { + // If other is not a ShardingMetadata, then it is clearly a no match. + return false; + } + if (sharding_ == nullptr) { + return other_ptr->sharding_ == nullptr; + } + return other_ptr->sharding_ != nullptr + ? ShardingMatches(*sharding_, *other_ptr->sharding_) + : false; +} + +string ShardingMetadata::ToString() const { + return sharding_ != nullptr ? sharding_->ToString() : "None"; +} + +Status ShardingMetadata::NormalizeInstructions( + const DomainMetadata::Domain& domain) const { + if (sharding_ != nullptr) { + VLOG(4) << "Normalizing sharding to " << sharding_->ToString() << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding_)); + TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding_)); + } + return Status::OK(); +} + +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain) { + TF_ASSIGN_OR_RETURN(std::unique_ptr sharding, + ExtractOriginalCommonSharding(domain.instructions)); + if (sharding != nullptr) { + VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString() + << ":"; + TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding)); + } else { + VLOG(1) << "Unable to find common sharding"; + } + return Status::OK(); +} + +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand) { + return CreateDomain(instruction, operand); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h new file mode 100644 index 0000000000000000000000000000000000000000..ec162c34904ee2dfac3daeeee37133282a9c9698 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ + +#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// A DomainMetadata implementation that internally wraps a sharding attribute. +class ShardingMetadata : public DomainMetadata { + public: + explicit ShardingMetadata(std::unique_ptr sharding) + : sharding_(std::move(sharding)) {} + + std::unique_ptr Clone() const override; + + tensorflow::StringPiece Kind() const override { return KindName(); } + + bool Matches(const DomainMetadata& other) const override; + + string ToString() const override; + + Status NormalizeInstructions( + const DomainMetadata::Domain& domain) const override; + + static tensorflow::StringPiece KindName() { return "sharding"; } + + private: + std::unique_ptr sharding_; +}; + +// Within a set of instructions which had common sharding attributes before +// entring the HLO passes pipeline, apply sharding heuristics and normalize the +// instructions whose sharding deviates from the one which is inferred as to be +// the original one. +// Policy wise, HLO passes are allowed to create new unassigned instructions, +// but if they do create assigned ones, they have to conform to the ones around. +Status NormalizeShardingDomain(const DomainMetadata::Domain& domain); + +// Given an HLO graph edge between instruction and one of its operands, creates +// a ShardingMetadata based kDomain instruction if the sharding between +// instruction and operand changes. Returns nullptr if there is no need for a +// domain separation. +std::unique_ptr CreateShardingDomain( + HloInstruction* instruction, HloInstruction* operand); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_METADATA_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index 3bf0d25efb7fad78aeccdd9269c289950b2171ab..54b7402b866361748d9eb35182b0bf486c4c9bdc 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/hlo_sharding.h" - #include #include #include #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -312,5 +311,50 @@ TEST_F(HloShardingTest, OstreamTest) { EXPECT_EQ(oss.str(), "{f32[3,5,7,11] devices=[1,1,2,2]0,1,2,3}"); } +TEST_F(HloShardingTest, ParseHloString) { + auto check = [](const HloSharding& sharding) { + TF_ASSERT_OK_AND_ASSIGN(auto parsed_sharding, + ParseSharding(sharding.ToString())); + EXPECT_EQ(sharding, parsed_sharding); + }; + check(HloSharding::Replicate()); + check(HloSharding::AssignDevice(2)); + check(HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}}))); + // Empty tuple. One sharding is required for empty tuples, as we need to be + // able to assign sharding to them, even though they have no leaves. + check(HloSharding::Tuple(ShapeUtil::MakeTupleShape({}), + {HloSharding::Replicate()})); + { + // Non-nested tuple. + auto tuple_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})}); + check(HloSharding::Tuple( + tuple_shape, {HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)})); + } + { + // Nested tuple. + auto tuple_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 1, 5, 7}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3, 5, 7}), + ShapeUtil::MakeShape(F32, {3, 7})})}); + std::vector leaf_shardings = { + HloSharding::Tile(ShapeUtil::MakeShape(F32, {3, 1, 3, 7}), + Array4D({{{{0}, {1}}}})), + HloSharding::Replicate(), HloSharding::AssignDevice(1)}; + ShapeTree sharding_tree(tuple_shape, HloSharding::Replicate()); + // Assign leaf_shardings to sharding_tree leaves. + auto it = leaf_shardings.begin(); + for (auto& index_to_sharding : sharding_tree.leaves()) { + index_to_sharding.second = *it++; + } + check(HloSharding::Tuple(sharding_tree)); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/service/hlo_token.h similarity index 84% rename from tensorflow/compiler/xla/tools/parser/hlo_token.h rename to tensorflow/compiler/xla/service/hlo_token.h index 7928bee5c2097f353b182095a555c334d7b69c95..533429608bc2e13626a3e746fbe465398e1f4bb4 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/service/hlo_token.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ -#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ #include @@ -22,9 +22,11 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace xla { -namespace tools { // Defines different kinds of tokens in a hlo module string. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. enum class TokKind { // Markers kEof, @@ -72,7 +74,6 @@ enum class TokKind { string TokKindToString(TokKind kind); -} // namespace tools } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 7d6d0d9eaf70969c1a3762959233b561706398c2..9034073cc8a82311297ccd087741e6713110a5a7 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -376,6 +376,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: + case HloOpcode::kDomain: case HloOpcode::kFusion: case HloOpcode::kGetTupleElement: case HloOpcode::kInfeed: @@ -425,6 +426,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { + std::vector operand_shapes; + for (const HloInstruction* operand : token->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes)); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -790,6 +799,46 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. For example, TOKEN types have no Literal representation and cannot be +// on the interface of the entry computation (parameters and root instruction). +Status VerifyEntryAndExitShapes(const HloModule& module) { + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape()).c_str()); + } + } + if (ShapeContainsToken( + module.entry_computation()->root_instruction()->shape())) { + return InternalError( + "Entry root is or contains a token shape: %s", + ShapeUtil::HumanString( + module.entry_computation()->root_instruction()->shape()) + .c_str()); + } + return Status::OK(); +} + +} // namespace + StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -850,6 +899,8 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } + TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 1392a78097aa026b2f7cffa2b0135402d3ca7ae5..7283b3e7dcdbed5be18a1da1571287cf0c089288 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,6 +81,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleGenerateToken(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index dc3bfce0c495bc40a2df7b985cab67e02a3e15ce..d7458c338e9f1df9fac90270845aae0b8f779ee2 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -169,6 +169,23 @@ string HumanReadableProfileBuilder::ToString() const { StrAppend(&s, table.MakeReport(CyclesToMicroseconds(total_cycles_))); } } + + if (total_bytes > 0) { + MetricTableReport table; + table.SetMetricName("MiB read+written"); + table.SetEntryName("ops"); + table.SetShowCategoryTable(); + for (const auto& op : op_infos_) { + MetricTableReport::Entry entry; + entry.text = op.name; + entry.short_text = op.short_name; + entry.category_text = op.category; + entry.metric = static_cast(op.bytes_accessed) / (1 << 20); + table.AddEntry(std::move(entry)); + } + StrAppend(&s, + table.MakeReport(static_cast(total_bytes) / (1 << 20))); + } return s; } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 15b2d8f4990735c56f105e7c1b9b7dc70609d898..8b3fa6c1572cf0ed91fc427722edcb23d8b8529d 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/indexed_array_analysis.h" #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -28,9 +29,11 @@ using Analysis = IndexedArrayAnalysis; using UnknownArray = Analysis::UnknownArray; using ConstantArray = Analysis::ConstantArray; using ScalarIndexedArray = Analysis::ScalarIndexedArray; +using tensorflow::gtl::ArraySlice; +using tensorflow::str_util::Join; } // namespace -string IndexedArrayAnalysis::ToString(Array* root) { +string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { switch (root->kind()) { case Array::kUnknown: { auto* unknown_tensor = root->as(); @@ -39,6 +42,12 @@ string IndexedArrayAnalysis::ToString(Array* root) { } case Array::kConstant: { + if (print_constants) { + string contents = root->as()->literal()->ToString(); + return tensorflow::strings::StrCat( + "(constant ", ShapeUtil::HumanString(root->shape()), " ", contents, + ")"); + } return tensorflow::strings::StrCat( "(constant ", ShapeUtil::HumanString(root->shape()), ")"); } @@ -50,26 +59,26 @@ string IndexedArrayAnalysis::ToString(Array* root) { ? "scalar-indexed-const" : "scalar-indexed"; return tensorflow::strings::StrCat( - "(", name, " ", ToString(indexed_array->source()), " ", - ToString(indexed_array->indices()), " ", indexed_array->source_dim(), - "->[", tensorflow::str_util::Join(indexed_array->output_dims(), ","), - "])"); + "(", name, " ", ToString(indexed_array->source(), print_constants), + " ", ToString(indexed_array->indices(), print_constants), " ", + indexed_array->source_dim(), "->[", + Join(indexed_array->output_dims(), ","), "])"); } } } -Analysis::Array* IndexedArrayAnalysis::GetArrayFor( +StatusOr IndexedArrayAnalysis::GetArrayFor( const HloInstruction* instr) { auto it = cache_.find(instr); if (it != cache_.end()) { return it->second; } - TraverseAndPopulateCache(instr); + TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr)); return FindOrDie(cache_, instr); } -void IndexedArrayAnalysis::TraverseAndPopulateCache( +Status IndexedArrayAnalysis::TraverseAndPopulateCache( const HloInstruction* root) { // Depth first search over the DAG, invoking ComputeArrayFor in post order. // The HLO instructions already in the cache are considered leaves. @@ -105,28 +114,46 @@ void IndexedArrayAnalysis::TraverseAndPopulateCache( case kVisited: stack.pop_back(); - InsertOrDie(&cache_, instr, ComputeArrayFor(instr)); + TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr)); + InsertOrDie(&cache_, instr, array); break; } } while (!stack.empty()); + + return Status::OK(); } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor( +StatusOr IndexedArrayAnalysis::ComputeArrayFor( const HloInstruction* instr) { Array* computed_array; - switch (instr->opcode()) { - default: - computed_array = nullptr; - break; - case HloOpcode::kConstant: - computed_array = ComputeArrayForConstant(instr->literal()); - break; - case HloOpcode::kGather: - computed_array = ComputeArrayForGather( - instr->shape(), instr->gather_dimension_numbers(), - instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)), - FindOrDie(cache_, instr->operand(1))); - break; + if (instr->IsElementwise() && instr->operand_count() == 1) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseUnaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)))); + } else if (instr->IsElementwise() && instr->operand_count() == 2) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForElementwiseBinaryOp( + instr->opcode(), FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForConstant(instr->literal())); + } else if (instr->opcode() == HloOpcode::kGather) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); + } else if (instr->opcode() == HloOpcode::kReshape) { + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForReshape(instr->shape(), + FindOrDie(cache_, instr->operand(0)))); + } else { + computed_array = nullptr; } if (!computed_array) { @@ -136,12 +163,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor( return computed_array; } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant( +StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( const Literal& literal) { return Construct(&literal); } -ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather( +StatusOr IndexedArrayAnalysis::FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice output_dims, Shape shape) { // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). @@ -161,14 +188,14 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather( IndexComponent::Ungathered); // Simulate the first gather. - simulated_index.erase(simulated_index.begin() + source->source_dim()); + EraseAt(&simulated_index, source->source_dim()); for (int64 gather_dim : source->output_dims()) { simulated_index.insert(simulated_index.begin() + gather_dim, IndexComponent::GatheredFirst); } // Simulate the second gather. - simulated_index.erase(simulated_index.begin() + source_dim); + EraseAt(&simulated_index, source_dim); for (int64 output_dim : output_dims) { simulated_index.insert(simulated_index.begin() + output_dim, IndexComponent::GatheredSecond); @@ -207,7 +234,7 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather( std::move(shape)); } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather( +StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, tensorflow::gtl::ArraySlice window_bounds, Array* source, Array* indices) { @@ -244,6 +271,443 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather( shape); } +namespace { +// Returns an index into `values` such that the product of the range +// [values.begin()+index, values.end()) is equal to `product`. If there is no +// such index, return -1. All integers in `values` must be positive. +int64 FindSuffixWithProduct(ArraySlice values, int64 product) { + DCHECK(c_all_of(values, [](int64 value) { return value > 0; })); + + int64 current_product = 1; + int64 i; + for (i = values.size() - 1; i >= 0 && product > current_product; --i) { + current_product *= values[i]; + } + + if (product == current_product) { + return i + 1; + } + + return -1; +} + +struct ReshapePassthroughDimPair { + int64 result_dim; + int64 operand_dim; +}; + +// Returns a set of dimension pairs such for all (result_dim, operand_dim) in +// the set: +// +// output_index[result_dim] = SourceIndexOfReshape(output_index)[operand_dim] +// +// The returned vector of pairs is sorted in both the result_dim and the +// operand_dim components. +std::vector ComputeReshapePassthroughDimPairs( + ArraySlice operand_shape, ArraySlice result_shape) { + // A reshape can be seen as an index mapping from output index to input index: + // + // (i_0, ..., i_n) = f(o_0, ..., o_m) + // + // This function returns the pairs (j, k) for which the following invariant + // holds for all indices in the shape: + // + // o_j == i_k + // + // And this occurs when: + // + // O_{j+1} * ... * O_n == I_{k+1} * ... * I_m + // + // (where O_x are the sizes of the output shape and I_x are the sizes of the + // input shape) and the size of the dimension j of the result is the same as + // the size of dimension k in the operand. + // + // These conditions are sufficient because the Reshape HLO is spec'ed such + // that the rightmost dimensions are always minor in the flattening and refine + // operation. + + std::vector result; + int64 result_subarray_size = 1; + for (int64 result_dim = result_shape.size() - 1; result_dim >= 0; + --result_dim) { + int64 candidate_operand_dim = + FindSuffixWithProduct(operand_shape, result_subarray_size); + + // result_subarray_size does not include the elements in the current + // `result_dim` dimension (we multiply in result_shape[result_dim] at the + // end of loop body) so candidate_operand_dim can never be zero. + CHECK_NE(candidate_operand_dim, 0); + + if (candidate_operand_dim != -1 && + result_shape[result_dim] == operand_shape[candidate_operand_dim - 1]) { + result.push_back({/*result_dim=*/result_dim, + /*operand_dim=*/candidate_operand_dim - 1}); + } + result_subarray_size *= result_shape[result_dim]; + } + + c_reverse(result); + + if (VLOG_IS_ON(3)) { + std::vector result_strings; + c_transform(result, std::back_inserter(result_strings), + [](ReshapePassthroughDimPair value) { + return tensorflow::strings::StrCat(value.result_dim, "->", + value.operand_dim); + }); + VLOG(3) << "For a reshape from [" << Join(operand_shape, ",") << "] to [" + << Join(result_shape, ",") << "] passthrough indices are [" + << Join(result_strings, ",") << "]"; + } + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.result_dim < rhs.result_dim; + })); + + DCHECK(c_is_sorted( + result, [](ReshapePassthroughDimPair lhs, ReshapePassthroughDimPair rhs) { + return lhs.operand_dim < rhs.operand_dim; + })); + + return result; +} + +// Return true if `dim` is stated as an passthrough operand dim in +// `passthrough_dims`. +bool IsReshapePassthroughOperandDim( + ArraySlice passthrough_dims, int64 dim) { + return c_any_of(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == dim; + }); +} + +// Maps `operand_dim` which must be an passthrough operand dimension to its +// corresponding passthrough result dimension based on `passthrough_dims`. +int64 MapPassthroughOperandDimToResultDim( + ArraySlice passthrough_dims, int64 operand_dim) { + auto it = c_find_if(passthrough_dims, + [&](ReshapePassthroughDimPair passthrough_dim_pair) { + return passthrough_dim_pair.operand_dim == operand_dim; + }); + CHECK(it != passthrough_dims.end()); + return it->result_dim; +} + +int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, + ArraySlice result_shape, + int64 source_passthrough_dim) { + int64 indexed_source_subarray_size = + std::accumulate(operand_shape.begin() + source_passthrough_dim + 1, + operand_shape.end(), 1, std::multiplies()); + + return FindSuffixWithProduct(result_shape, indexed_source_subarray_size); +} + +}; // namespace + +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( + const Shape& shape, Array* operand) { + auto* scalar_indexed = dynamic_cast(operand); + if (!scalar_indexed) { + return nullptr; + } + + // Try to fold Reshape(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(Const', Indices) + // + // We can view the reshape and the scalar-indexed operations as functions that + // map an output index (i.e. an index into the result) to an input index + // (i.e. an index into the operand). The key idea used here is that the + // output-to-input mapping for some reshape operations may "pass through" some + // output dimensions into the input space unchanged -- i.e. there may exist + // output dimension "O" and input dimension "I" such that OutputIndex[O] is + // always == InputIndexForReshape(OutputIndex)[I]. If these pass-through + // dimensions in the input space of the reshape happen to be include all the + // output dimensions for the scalar-indexed node then, roughly, the following + // holds: + // + // SourceIndexOfScalarIndexed(SourceIndexOfReshape(Idx)) + // == SourceIndexOfScalarIndexed(SourceIndexOfReshape(Ps ++ Qs)) + // + // Where Ps are the set of the pass-through components of Idx that are + // also the output dims of the scalar-indexed node, and Qs are the rest. + // For brevity, we're playing fast and loose with the notation here -- we + // don't literally require Idx to be a concatenation of Ps and Qs, as + // suggested by the "++". + // + // == SourceIndexOfScalarIndexed(Ps ++ SourceIndexOfReshape(Qs)) + // + // Again, we're playing fast and loose with the notation around "++". + // Generally this ++ will be a different function that the ++ in the + // previous step. + // + // If the scalar-indexed node has a constant as the source then the + // SourceIndexOfReshape function can be "folded into" the constant itself by + // reshaping it, leaving us with: + // + // == SourceIndexOfScalarIndexed(Ps ++ Qs) + // == SourceIndexOfScalarIndexed(Idx) + // + // which is just a scalar-indexed node (with parameters different from the + // scalar-indexed node we started with) with a reshaped constant as the + // source. + // + // We can't fold SourceIndexOfReshape into the constant without introducing + // another precondition: since the new scalar-indexed node will have a + // reshaped (constant) array as its source it will, in general, have a + // different source dimension than the original scalar-indexed node. This + // source dimension will have to be a passthrough dimension of the + // SourceIndexOfReshape indexing function that is folded into the source. And + // such a dimension need not exist so this is a non-trivial precondition. + + std::vector reshape_passthrough_dims = + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice(operand->shape().dimensions()), + /*result_shape=*/AsInt64Slice(shape.dimensions())); + + auto is_reshape_passthrough_operand_dim = [&](int64 operand_dim) { + return IsReshapePassthroughOperandDim(reshape_passthrough_dims, + operand_dim); + }; + + if (!c_all_of(scalar_indexed->output_dims(), + is_reshape_passthrough_operand_dim)) { + return nullptr; + } + + // To compute the shape of the source for the new scalar-indexed node we're + // going to create, we first "undo" the scalar-indexed operation. + std::vector new_scalar_indexed_source_shape(shape.dimensions().begin(), + shape.dimensions().end()); + for (int64 i = scalar_indexed->output_dims().size() - 1; i >= 0; i--) { + int64 output_dim = scalar_indexed->output_dims()[i]; + int64 output_dim_after_reshape = MapPassthroughOperandDimToResultDim( + reshape_passthrough_dims, output_dim); + EraseAt(&new_scalar_indexed_source_shape, output_dim_after_reshape); + } + + // After this, we need to add in the dimension that will be the source + // dimension for the new scalar-indexed node. A scalar-indexed node "removes" + // the source dimensions and "adds" the output dimensions, so to get back to + // the shape for the *source* of the scalar-indexed node we need to remove the + // output dims (which we did above) and then add back the source dim (which we + // are about to do below): + + const Shape& scalar_indexed_source_shape = scalar_indexed->source()->shape(); + + int64 source_dim_for_new_scalar_indexed_node = + FindSourcePositionForPassthroughResultDim( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape, + scalar_indexed->source_dim()); + + // We may not be able to find a source dim for the new scalar-indexed node. + // For instance consider: + // + // operand = s32[3,5,2] constant({...}) + // indices = s32[7] parameter(0) + // gather = s32[3,2,7] gather(operand, indices), + // output_window_dims={0,1}, + // elided_window_dims={1}, + // gather_dims_to_operand_dims={1}, + // index_vector_dim=1, + // window_bounds={3,1,2} + // reshape = s32[6,7] reshape(gather) + // + // In this case the gather maps to: + // (scalar-indexed-const (constant s32[3,5,2]) %indices 1->[2]) + // + // and the reshape passes through dimension 2 from its input into dimension 1 + // in its output. However, we can't rewrite the reshape as a scalar-indexed + // node because then we'd have to reshape the [3,5,2] `operand` array to + // [6,5], but then dimension 1 of the reshaped [6,5] array indexes differently + // (a.k.a. isn't pass-through) than the [3,5,2] array. + + if (source_dim_for_new_scalar_indexed_node == -1) { + return nullptr; + } + + InsertAt( + &new_scalar_indexed_source_shape, source_dim_for_new_scalar_indexed_node, + scalar_indexed_source_shape.dimensions(scalar_indexed->source_dim())); + + CHECK(IsReshapePassthroughOperandDim( + ComputeReshapePassthroughDimPairs( + /*operand_shape=*/AsInt64Slice( + scalar_indexed_source_shape.dimensions()), + /*result_shape=*/new_scalar_indexed_source_shape), + scalar_indexed->source_dim())); + + auto map_passthrough_operand_dim_to_result_dim = [&](int64 result_dim) { + return MapPassthroughOperandDimToResultDim(reshape_passthrough_dims, + result_dim); + }; + + std::vector output_dims_for_new_scalar_indexed_node; + c_transform(scalar_indexed->output_dims(), + std::back_inserter(output_dims_for_new_scalar_indexed_node), + map_passthrough_operand_dim_to_result_dim); + + TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, + TakeOwnership(scalar_indexed->literal().Reshape( + new_scalar_indexed_source_shape))); + TF_ASSIGN_OR_RETURN( + Array * new_scalar_indexed_source, + ComputeArrayForConstant(*new_scalar_indexed_source_literal)); + + return ConstructScalarIndexedArray( + new_scalar_indexed_source, scalar_indexed->indices(), + source_dim_for_new_scalar_indexed_node, + output_dims_for_new_scalar_indexed_node, shape); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, + Array* rhs) { + // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices)) + // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices) + // + // We can do this if every output dimension from the scalar-indexed node is a + // broadcasted dimension for the broadcast node. Informally, the precondition + // means Broadcast(Const0)[IDX] is solely a function of the components of IDX + // that are not output-dims for the scalar-indexed node. In other words, for + // every assignment to the non-output dims in IDX we have a "constant" LHS to + // the BinaryOp. This transform propagates this "constant" to the source for + // the scalar-indexed node. + + ScalarIndexedConstantArray* lhs_scalar_indexed_const = + dynamic_cast(lhs); + ScalarIndexedConstantArray* rhs_scalar_indexed_const = + dynamic_cast(rhs); + + bool lhs_is_indexed; + + // One of the operands must be scalar-indexed and the other must be a + // broadcast of a constant. + if (lhs_scalar_indexed_const && !rhs_scalar_indexed_const) { + lhs_is_indexed = true; + } else if (rhs_scalar_indexed_const && !lhs_scalar_indexed_const) { + lhs_is_indexed = false; + } else { + return nullptr; + } + + ScalarIndexedConstantArray* scalar_indexed_const = + lhs_is_indexed ? lhs_scalar_indexed_const : rhs_scalar_indexed_const; + UnknownArray* candidate_broadcast_array = + dynamic_cast(lhs_is_indexed ? rhs : lhs); + if (!candidate_broadcast_array || + candidate_broadcast_array->instruction().opcode() != + HloOpcode::kBroadcast) { + return nullptr; + } + + const HloInstruction* broadcast_instr = + &candidate_broadcast_array->instruction(); + const HloInstruction* broadcast_const_operand = broadcast_instr->operand(0); + if (broadcast_const_operand->opcode() != HloOpcode::kConstant) { + return nullptr; + } + + ArraySlice broadcast_dims = broadcast_instr->dimensions(); + auto is_broadcasted_dim = [&](int64 output_dim) { + return c_find(broadcast_dims, output_dim) == broadcast_dims.end(); + }; + + // All of the output dims must be "broadcasted" dims for the other operand. + if (!c_all_of(scalar_indexed_const->output_dims(), is_broadcasted_dim)) { + return nullptr; + } + + // To figure out the broadcast dimensions for the (constant) source for the + // scalar-indexed node, we "simulate" the index transformation done by the + // existing broadcsat: + enum class IndexComponent { Broadcasted, NotBroadcasted }; + std::vector simulated_index( + broadcast_instr->shape().dimensions_size(), IndexComponent::Broadcasted); + for (int64 broadcast_dim : broadcast_dims) { + simulated_index[broadcast_dim] = IndexComponent::NotBroadcasted; + } + + // The scalar-indexed node "removes" the source dim and "inserts" the output + // dims. We do the opposite here to undo the scalar-indexed operation. + ArraySlice output_dims = scalar_indexed_const->output_dims(); + for (int64 i = output_dims.size() - 1; i >= 0; --i) { + CHECK(simulated_index[output_dims[i]] == IndexComponent::Broadcasted); + EraseAt(&simulated_index, output_dims[i]); + } + + InsertAt(&simulated_index, scalar_indexed_const->source_dim(), + IndexComponent::Broadcasted); + + // new_inner_broadcast_dims holds the broadcast dimensions for the inner + // BinaryOp(Broadcast'(Const0), Const1). We now translate simulated_index to + // new_inner_broadcast_dims. + std::vector new_inner_broadcast_dims; + for (int64 i = 0; i < simulated_index.size(); i++) { + if (simulated_index[i] == IndexComponent::NotBroadcasted) { + new_inner_broadcast_dims.push_back(i); + } + } + + // inner_broadcast_result is the Broadcast'(Const0) bit in + // BinaryOp(Broadcast'(Const0), Const1) + TF_ASSIGN_OR_RETURN( + std::unique_ptr inner_broadcast_result, + broadcast_const_operand->literal().Broadcast( + scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); + + // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1) + const Literal* literal_for_new_source; + if (lhs_is_indexed) { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, scalar_indexed_const->literal(), *inner_broadcast_result))); + } else { + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + opcode, *inner_broadcast_result, scalar_indexed_const->literal()))); + } + + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand) { + auto* scalar_indexed_const = + dynamic_cast(operand); + if (scalar_indexed_const == nullptr) { + return nullptr; + } + + // Fold UnaryOp(ScalarIndexed(Const, Indices)) + // => ScalarIndexed(UnaryOp(Const), Indices) + + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp( + opcode, scalar_indexed_const->literal()))); + ConstantArray* new_source = Construct(literal_for_new_source); + return Construct( + new_source, scalar_indexed_const->indices(), + scalar_indexed_const->source_dim(), + std::vector(scalar_indexed_const->output_dims().begin(), + scalar_indexed_const->output_dims().end()), + scalar_indexed_const->shape()); +} + tensorflow::StringPiece IndexedArrayAnalysisPrinterPass::name() const { return "indexed-array-analysis-printer-pass"; } @@ -256,7 +720,7 @@ StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { IndexedArrayAnalysis analysis; for (auto* computation : module->MakeNonfusionComputations()) { for (auto* instr : computation->instructions()) { - auto* t = analysis.GetArrayFor(instr); + TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr)); if (!dynamic_cast(t) && !dynamic_cast(t)) { VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index b132a8f25153d2e86e8aa477fdb851f1c9c8e719..ce92fd2919c90fa8a2fb7b796ed6f0fdaf48fe62 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -143,8 +143,8 @@ class IndexedArrayAnalysis { // // For example, if source is of shape [11,13,17,19], indices is of shape // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of - // shape [23,11,29,19] and the output index [A,B,C,D,E] is mapped to the input - // index [B,D,indices[A,C],E]. + // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the + // input index [B,D,indices[A,C],E]. class ScalarIndexedArray : public Array { public: Kind kind() const override { return kScalarIndexed; } @@ -152,7 +152,15 @@ class IndexedArrayAnalysis { Array* source() const { return source_; } Array* indices() const { return indices_; } + + // `source_dim` is the dimension in the source array that is being indexed + // over using indices from the `indices` array. See the class documentation + // and the overview for more details. int64 source_dim() const { return source_dim_; } + + // `output_dims` are the dimensions in the output array that are being used + // to compute an index into the `indices` array. See the class + // documentation and the overview for more details. tensorflow::gtl::ArraySlice output_dims() const { return output_dims_; } @@ -212,26 +220,26 @@ class IndexedArrayAnalysis { // NB! By inspecting the implementation, you may be able to infer a stronger // caching guarantee than what is mentioned above. Nevertheless, what is // stated above is the contract. - Array* GetArrayFor(const HloInstruction* instr); + StatusOr GetArrayFor(const HloInstruction* instr); // Pretty-prints the expression rooted at `root`. - string ToString(Array* root); + string ToString(Array* root, bool print_constants = false); private: // Helper function that ensures that every HLO instruction that is // transitively used by `root` has an entry in `cache_`. - void TraverseAndPopulateCache(const HloInstruction* root); + Status TraverseAndPopulateCache(const HloInstruction* root); // Creates an Array instance for `instr` under the assumption that all // operations of `instr` are present in `cache_`. - Array* ComputeArrayFor(const HloInstruction* instr); + StatusOr ComputeArrayFor(const HloInstruction* instr); - Array* ComputeArrayForConstant(const Literal& literal); + StatusOr ComputeArrayForConstant(const Literal& literal); - Array* ComputeArrayForGather(const Shape& shape, - const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, - Array* source, Array* indices); + StatusOr ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a @@ -254,10 +262,17 @@ class IndexedArrayAnalysis { // // I2 = [I0[i] for i in I1] // G1 = [Arr[i] for i in I2] - ScalarIndexedArray* FoldGatherOfGather( + StatusOr FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice output_dims, Shape shape); + StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); + + StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, Array* rhs); + StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand); + template T* Construct(Args&&... args) { T* new_tensor = new T(std::forward(args)...); @@ -279,6 +294,19 @@ class IndexedArrayAnalysis { } } + Literal* TakeOwnership(std::unique_ptr literal) { + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + + StatusOr TakeOwnership( + StatusOr> literal_or_error) { + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + std::vector> owned_tensors_; std::vector> owned_literals_; tensorflow::gtl::FlatMap cache_; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index b2731b7c51a45c4f9b713d99ef3e4623ad2c9c83..373556ebeba883f7dc2116bdf0ffc3274182f775 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -23,14 +23,31 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { protected: void AssertArrayForRootExpressionIs(const string& hlo_text, const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/false); + } + + void AssertArrayWithConstantsForRootExpressionIs( + const string& hlo_text, const string& root_expression) { + AssertArrayForRootExpressionIsImpl(hlo_text, root_expression, + /*print_constants=*/true); + } + + private: + void AssertArrayForRootExpressionIsImpl(const string& hlo_text, + const string& root_expression, + bool print_constants) { IndexedArrayAnalysis indexed_tensor_analysis; ParseAndVerifyModule(hlo_text); - string result = - indexed_tensor_analysis.ToString(indexed_tensor_analysis.GetArrayFor( + TF_ASSERT_OK_AND_ASSIGN( + IndexedArrayAnalysis::Array* const array_result, + indexed_tensor_analysis.GetArrayFor( module().entry_computation()->root_instruction())); - LOG(INFO) << result; - ASSERT_EQ(result, root_expression); + string string_result = + indexed_tensor_analysis.ToString(array_result, print_constants); + LOG(INFO) << string_result; + ASSERT_EQ(string_result, root_expression); } }; @@ -187,5 +204,301 @@ ENTRY main { "(scalar-indexed %operand (scalar-indexed %indices_a %indices_b " "1->[0,2]) 1->[0,1,3])"); } + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5] parameter(0) + gather = s32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT reshape = s32[5,2,2] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,7] parameter(0) + gather = s32[5,4,7] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,2,2]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather2) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,2,6] constant(s32[3,2,6]{ + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}, + {{1,2,3,4,5,6},{1,2,3,4,5,6}}}) + indices = s32[5,7] parameter(0) + gather = s32[5,2,6,7] gather(operand, indices), + output_window_dims={1,2}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,2,6} + ROOT reshape = s32[5,3,4,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs( + hlo_text, + "(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,2,3,4},{1,2,3,4}}) + indices = s32[5,6] parameter(0) + gather = s32[5,4,6] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=2, + window_bounds={1,4} + ROOT reshape = s32[5,2,2,2,3] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) { + string hlo_text = R"( +HloModule ReshapeOfGather + +ENTRY main { + operand = s32[3,5,2] constant(s32[3,5,2]{ + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}, + {{1,2},{3,4},{5,6},{7,8},{9,10}}}) + indices = s32[7] parameter(0) + gather = s32[3,2,7] gather(operand, indices), + output_window_dims={0,1}, + elided_window_dims={1}, + gather_dims_to_operand_dims={1}, + index_vector_dim=1, + window_bounds={3,1,2} + ROOT reshape = s32[6,7] reshape(gather) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%reshape"); +} + +TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) { + string hlo_text = R"( +HloModule UnaryOpOfGather + +ENTRY main { + operand = f32[3,4] constant(f32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + indices = s32[5] parameter(0) + gather = f32[5,4] gather(operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT tanh = f32[5,4] tanh(gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant f32[3,4] f32[3,4] { + { 0.761594176, 0.964027584, 0.995054781, 0.999329329 }, + { 0.761594176, 0.995054781, 0.964027584, 0.999329329 }, + { 0.999329329, 0.995054781, 0.964027584, 0.761594176 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedScalarWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 6, 7, 8, 9 }, + { 6, 8, 7, 9 }, + { 9, 8, 7, 6 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsLhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { -4, -3, -2, -1 }, + { -4, -2, -3, -1 }, + { -1, -2, -3, -4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, + SubtractBroadcastedScalarWithGather_GatherIsRhs) { + string hlo_text = R"( +HloModule SubtractBroadcastedScalarWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant = s32[] constant(5) + constant_broadcasted = s32[5,4] broadcast(constant), dimensions={} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT sub = s32[5,4] subtract(constant_broadcasted, gather) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 4, 3, 2, 1 }, + { 4, 2, 3, 1 }, + { 1, 2, 3, 4 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[4] constant({10,11,12,13}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={1} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayWithConstantsForRootExpressionIs(hlo_text, 1 + R"( +(scalar-indexed-const (constant s32[3,4] s32[3,4] { + { 11, 13, 15, 17 }, + { 11, 14, 14, 17 }, + { 14, 14, 14, 14 } +}) %indices 0->[0]))"); +} + +TEST_F(IndexedArrayAnalysisTest, AddBroadcastedVectorWithGather_Negative) { + string hlo_text = R"( +HloModule AddBroadcastedVectorWithGather + +ENTRY main { + gather_operand = s32[3,4] constant(s32[3,4]{{1,2,3,4},{1,3,2,4},{4,3,2,1}}) + constant_vect = s32[5] constant({10,11,12,13,14}) + constant_broadcasted = s32[5,4] broadcast(constant_vect), dimensions={0} + indices = s32[5] parameter(0) + gather = s32[5,4] gather(gather_operand, indices), + output_window_dims={1}, + elided_window_dims={0}, + gather_dims_to_operand_dims={0}, + index_vector_dim=1, + window_bounds={1,4} + ROOT add = s32[5,4] add(gather, constant_broadcasted) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularUnaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input = f32[100] parameter(0) + ROOT tanh = f32[100] tanh(input) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%tanh"); +} + +TEST_F(IndexedArrayAnalysisTest, RegularBinaryOp) { + string hlo_text = R"( +HloModule RegularUnaryOp + +ENTRY main { + input0 = f32[100] parameter(0) + input1 = f32[100] parameter(1) + ROOT add = f32[100] add(input0, input1) +} +)"; + + AssertArrayForRootExpressionIs(hlo_text, "%add"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index cb6c98c48171a06539499b723a8d8b7aa0ccc96a..abedb4063d3763516e66cff36633dbd90c8cafde 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -96,6 +96,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: + case HloOpcode::kGenerateToken: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; @@ -118,6 +119,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kDivide: + case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kExp: case HloOpcode::kExpm1: @@ -178,8 +180,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) { bool InstructionFusion::CanFuseOnAllPaths( HloInstruction* producer, HloInstruction* consumer, - const HloReachabilityMap& reachability_map, - const DoNotFuseSet& do_not_fuse) { + const HloInstructionSet& do_not_duplicate) { if (consumer == producer) { return true; } @@ -190,10 +191,11 @@ bool InstructionFusion::CanFuseOnAllPaths( auto* consumer_operand = consumer->mutable_operand(i); // If the operand is not on a path to the producer, it doesn't matter // whether it's fusable. - if (!reachability_map.IsReachable(producer, consumer_operand)) { + if (!reachability_->IsReachable(producer, consumer_operand)) { continue; } - if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) { + if (do_not_duplicate.count(consumer_operand) > 0 || + !ShouldFuse(consumer, i)) { return false; } // The producer is reachable from consumer_operand which means we need @@ -201,18 +203,16 @@ bool InstructionFusion::CanFuseOnAllPaths( // producer to be fusable into consumer on all paths. // Perform the recursive step: make sure producer can be fused into // consumer_operand on all paths. - if (!CanFuseOnAllPaths(producer, consumer_operand, reachability_map, - do_not_fuse)) { + if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_duplicate)) { return false; } } return true; } -InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( +InstructionFusion::HloInstructionSet +InstructionFusion::ComputeGloballyUnfusable( tensorflow::gtl::ArraySlice post_order) { - auto reachability = computation_->ComputeReachability(); - // Forbid fusion of producers that: // a) Need to be duplicated, unless they can be fused into all consumers // via all paths. @@ -222,10 +222,10 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( // Note that if we allow fusion by these global rules, we may still forbid // fusing operations that require duplication later depending on // is_expensive_(). - DoNotFuseSet do_not_fuse; + HloInstructionSet do_not_duplicate; for (HloInstruction* consumer : post_order) { for (HloInstruction* producer : consumer->operands()) { - if (do_not_fuse.count(producer) > 0) { + if (do_not_duplicate.count(producer) > 0) { continue; } @@ -254,14 +254,14 @@ InstructionFusion::DoNotFuseSet InstructionFusion::ComputeGloballyUnfusable( // A will be not allowed to be fused into B, as it cannot be fused via // all paths. if (producer->IsFusable() && - CanFuseOnAllPaths(producer, consumer, *reachability, do_not_fuse)) { + CanFuseOnAllPaths(producer, consumer, do_not_duplicate)) { continue; } - do_not_fuse.insert(producer); + do_not_duplicate.insert(producer); } } - return do_not_fuse; + return do_not_duplicate; } StatusOr InstructionFusion::Run(HloModule* module) { @@ -273,6 +273,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; + reachability_ = computation_->ComputeReachability(); // We want to be able to remove arbitrary instructions from the post order // and also compare positions of instructions in the post order. To make @@ -290,7 +291,7 @@ StatusOr InstructionFusion::Run(HloModule* module) { InsertOrDie(&post_order_index, post_order[i], i); } - DoNotFuseSet do_not_fuse = ComputeGloballyUnfusable(post_order); + HloInstructionSet do_not_duplicate = ComputeGloballyUnfusable(post_order); // Instruction fusion effectively fuses edges in the computation graph // (producer instruction -> consumer instruction) so we iterate over all @@ -358,9 +359,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { // ensures that B will be considered before A. // // We store the original indices of the operands to pass to ShouldFuse. - std::vector sorted_operand_numbers(instruction->operands().size()); - std::iota(std::begin(sorted_operand_numbers), - std::end(sorted_operand_numbers), 0); + std::vector sorted_operand_numbers; + sorted_operand_numbers.reserve(instruction->operands().size()); + for (int i = 0; i < instruction->operands().size(); ++i) { + // This will happen if we have two possible instructions to fuse the + // same operand into; once the operand is fused into one instruction, + // the other instruction will get a new get-tuple-element as its + // operand, which is not in the post-order index. + // TODO(tjoerg): Look into fusing past these multi-output fuse points. + if (post_order_index.find(instruction->mutable_operand(i)) == + post_order_index.end()) { + continue; + } + sorted_operand_numbers.push_back(i); + } std::sort( sorted_operand_numbers.begin(), sorted_operand_numbers.end(), [&](int64 i, int64 j) { @@ -377,13 +389,20 @@ StatusOr InstructionFusion::Run(HloModule* module) { if (!operand->IsFusable()) { continue; } - if (!ShouldFuse(instruction, i)) { - continue; - } - if (do_not_fuse.count(operand) > 0) { + + HloInstruction* fusion_instruction; + // Try "regular" fusion if the operand may be duplicated. Otherwise, + // perform multi-output fusion, unless this creates a cycle. + // TODO(tjoerg): Consider making multi-output fusion the default. + if (ShouldFuse(instruction, i) && + do_not_duplicate.count(operand) == 0) { + fusion_instruction = Fuse(operand, instruction); + } else if (ShouldFuseIntoMultiOutput(instruction, i) && + !MultiOutputFusionCreatesCycle(operand, instruction)) { + fusion_instruction = FuseIntoMultiOutput(operand, instruction); + } else { continue; } - HloInstruction* fusion_instruction = Fuse(operand, instruction); // Fusing an instruction into a fusion instruction can change the // operand set of the fusion instruction. For simplicity just push the @@ -449,6 +468,19 @@ HloInstruction* InstructionFusion::FuseIntoMultiOutput( return fusion_instruction; } +bool InstructionFusion::MultiOutputFusionCreatesCycle( + HloInstruction* producer, HloInstruction* consumer) { + return c_any_of( + consumer->operands(), [&](const HloInstruction* consumer_operand) { + // The fusion algorithm traverses the HLO graph in reverse post order. + // Thus `cosumers` is visited before its operands (including + // `producer`). Therefore, consumer operands cannot have been fused yet. + // It is thus safe to use the pre-computed reachability map. + return consumer_operand != producer && + reachability_->IsReachable(producer, consumer_operand); + }); +} + bool InstructionFusion::ShouldFuse(HloInstruction* consumer, int64 operand_index) { HloInstruction* producer = consumer->mutable_operand(operand_index); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index c3c2ed0aaa81d6f346ec6e70d9c8b3b923e0a3d2..f73ca9adf768ed26f9ec9f162e01b7b160f50daf 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -61,6 +61,14 @@ class InstructionFusion : public HloPassInterface { // Subtypes can override this with target-specific heuristics. virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index); + // Returns whether multi-output fusion can be applied to fuse `producer` into + // `consumer`. In contrast to "regular" fusion, the `producer` is not + // duplicated by multi-output fusion. + virtual bool ShouldFuseIntoMultiOutput(HloInstruction* consumer, + int64 operand_index) { + return false; + } + // Chooses a fusion kind for `producer` and `consumer`. // Default method chooses `kLoop`. virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, @@ -97,10 +105,12 @@ class InstructionFusion : public HloPassInterface { // Current HloComputation instance the loop fuser is traversing. HloComputation* computation_; HloModule* module_; + // Reachability information for the current computation. + std::unique_ptr reachability_; private: // The set of producers whose consumers we cannot fuse into. - using DoNotFuseSet = std::unordered_set; + using HloInstructionSet = std::unordered_set; HloInstruction* AddFusionInstruction(HloInstruction* producer, HloInstruction* consumer); @@ -108,18 +118,21 @@ class InstructionFusion : public HloPassInterface { // Whether or not we can fuse producer into consumer on all paths // from the producer to the consumer where nodes are HLOs and edges are uses. bool CanFuseOnAllPaths(HloInstruction* producer, HloInstruction* consumer, - const HloReachabilityMap& reachability_map, - const DoNotFuseSet& do_not_fuse); + const HloInstructionSet& do_not_fuse); // Computes the set of nodes that we do not want to fuse into any of their // consumers based on a global analysis of the HLO graph. - DoNotFuseSet ComputeGloballyUnfusable( + HloInstructionSet ComputeGloballyUnfusable( tensorflow::gtl::ArraySlice post_order); // Used to determine if an HLO is expensive. Expensive operations will not be // duplicated. std::function is_expensive_; + // Whether multi-output fusion would introduce a cycle into the HLO graph. + bool MultiOutputFusionCreatesCycle(HloInstruction* producer, + HloInstruction* consumer); + // Returns whether we may duplicate an instruction if we want to fuse it. bool may_duplicate_; diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index df109df7877eefe4c337f93cc5a3a7a48e2e76c7..21db2338995960bde00ec9c4b325e5562fc3a592 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { @@ -47,7 +47,7 @@ class InstructionFusionForTesting : public InstructionFusion { }; TEST_F(InstructionFusionTest, FuseInstructions) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -67,7 +67,7 @@ TEST_F(InstructionFusionTest, FuseInstructions) { } TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module fused_computation { p1 = f32[4,3] parameter(0) @@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, FuseIntoFusionInstruction) { } TEST_F(InstructionFusionTest, FuseInstructionsIntoMultiOutput) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY entry_computation { p0 = f32[4,3]{1,0} parameter(0) @@ -195,7 +195,7 @@ static int Count(const HloModule& module, HloOpcode op) { } TEST_F(InstructionFusionTest, FuseCheapNonDuplicatableOps) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -220,7 +220,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // // p0 -> add -------------------------> sub // \-> abs1 -> rng -> abs2 -/ - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -251,7 +251,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // p0 -> add -------------------------> sub // \-> abs1 -> log -> abs2 -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -282,7 +282,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \ \-> add2 -/ // \-> log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -314,7 +314,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusableRecursively) { // \------> sub1 // log -/ // \-> send - module = tools::Parse(R"( + module = ParseHloString(R"( HloModule test_module ENTRY OutputFusion { p0 = f32[4,3]{1,0} parameter(0) @@ -390,7 +390,7 @@ TEST_F(InstructionFusionTest, AllowEffectiveUnaryDuplication) { TEST_F(InstructionFusionTest, WideningConvertsAreAlwaysDuplicableIntoConsumers) { - auto module = tools::Parse(R"( + auto module = ParseHloString(R"( HloModule test_module ENTRY Test { p0 = f16[100] parameter(0) diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 7508013199a82267efc0e1426cb5989d5fe844a0..bf0448a67674f24591d866b646b98aea09ebb12c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -651,7 +651,7 @@ TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); module = backend() @@ -691,7 +691,7 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { } )"; - auto module = tools::Parse(module_str).ValueOrDie(); + auto module = ParseHloString(module_str).ValueOrDie(); ComputationLayout computation_layout( module->entry_computation()->ComputeProgramShape()); Shape param_shape = ShapeUtil::MakeTupleShape( diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index f172b1d87c870270436f7301ed200b47d08431a7..d909845a3a21fc55e44b0037371fca30e577980f 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -80,8 +80,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, /*Name=*/""); + llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast( + global, llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo()); generators_[constant] = [=](const IrArray::Index& index) { - return IrArray(global, constant->shape()) + return IrArray(shape_constant, constant->shape()) .EmitReadArrayElement(index, ir_builder_); }; diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6..1f6e3c829f890d68aa251b101f0402c120a19d61 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -15,53 +15,57 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, - const std::function& for_body_generator) { - If(ir_builder_->CreateICmpSLT(start, end), [&]() { - for_body_generator(start, /*is_first_iteration=*/true); - For(name, ir_builder_->CreateAdd(start, step), end, step, - [&](llvm::Value* iv) { for_body_generator(iv, false); }); + const std::function& for_body_generator) { + return If(ir_builder_->CreateICmpSLT(start, end), [&]() -> Status { + TF_RETURN_IF_ERROR(for_body_generator(start, /*is_first_iteration=*/true)); + return For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { return for_body_generator(iv, false); }); }); } -void KernelSupportLibrary::For( +Status KernelSupportLibrary::For( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, bool peel_first_iteration, - const std::function& for_body_generator) { + const std::function& + for_body_generator) { if (peel_first_iteration) { - For(name, start, end, step, true, - [&](llvm::Value* indvar, bool is_first_iteration) { - for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); - }); + return For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) -> Status { + return for_body_generator( + indvar, ir_builder_->getInt1(is_first_iteration)); + }); } else { std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( name, start, end, step, ir_builder_, - /*prevent_unrolling=*/prevent_unrolling_, + /*unroll_mode=*/unroll_mode_, /*prevent_vectorization=*/prevent_vectorization_); ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); - for_body_generator(loop->GetIndVarValue(), - /*is_first_iteration=*/ir_builder_->CreateICmpEQ( - loop->GetIndVarValue(), start)); + TF_RETURN_IF_ERROR( + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start))); llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + return Status::OK(); } } -void KernelSupportLibrary::If( - llvm::Value* condition, const std::function& true_block_generator, - const std::function& false_block_generator) { +Status KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(condition, "", ir_builder_); ir_builder_->SetInsertPoint(&if_data.true_block->back()); - true_block_generator(); + TF_RETURN_IF_ERROR(true_block_generator()); ir_builder_->SetInsertPoint(&if_data.false_block->back()); - false_block_generator(); + TF_RETURN_IF_ERROR(false_block_generator()); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); + return Status::OK(); } void KernelSupportLibrary::EmitAndCallOutlinedKernel( diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 64b935bbf1fb9033cd2e1259b4639cd3780be711..e17c649e5272a9e7c0d5126083ab76542abfdf48 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -30,13 +31,14 @@ namespace xla { class KernelSupportLibrary { public: // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. - // If `prevent_unrolling` is true then unrolling is explicitly disabled on - // every loop generated by this instance of KernelSupportLibrary. - explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = true, - bool prevent_vectorization = true) + // `unroll_mode` specifies the desired LLVM unrolling behavior for every loop + // generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary( + llvm::IRBuilder<>* ir_builder, + llvm_ir::UnrollMode unroll_mode = llvm_ir::UnrollMode::kNoUnroll, + bool prevent_vectorization = true) : ir_builder_(ir_builder), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} // Generates the following control flow structure: @@ -46,19 +48,41 @@ class KernelSupportLibrary { // for (i64 i = `start` + `step`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; // } - void For( + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator); + + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& - for_body_generator); + for_body_generator) { + CHECK_EQ(Status::OK(), + For(name, start, end, step, + [&](llvm::Value* ind_var, bool is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure if `peel_first_iteration` is @@ -75,46 +99,101 @@ class KernelSupportLibrary { // for (i64 i = `start`; i s< `end`; i += `step`) // `for_body_generator(/*ind_var=*/,i, // /*is_first_iteration=*/,(i != `start`))`; - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - llvm::Value* step, bool peel_first_iteration, - const std::function& + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, llvm::Value* step, + bool peel_first_iteration, + const std::function& + for_body_generator) { + TF_CHECK_OK(For( + name, start, end, step, peel_first_iteration, + [&](llvm::Value* ind_var, llvm::Value* is_first_iteration) -> Status { + for_body_generator(ind_var, is_first_iteration); + return Status::OK(); + })); + } + + Status For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + return For(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, for_body_generator); + } + + void ForReturnVoid(tensorflow::StringPiece name, llvm::Value* start, + llvm::Value* end, int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + ForReturnVoid(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, + for_body_generator); + } - void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, - int64 step, bool peel_first_iteration, - const std::function& - for_body_generator) { - For(name, /*start=*/start, /*end=*/end, - /*step=*/ir_builder_->getInt64(step), peel_first_iteration, - for_body_generator); + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + return For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, llvm::Value* step, const std::function& for_body_generator) { - For(name, start, end, step, - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { + return for_body_generator(indvar); + }); + } + + Status For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, + const std::function& for_body_generator) { + return For(name, start, end, ir_builder_->getInt64(step), + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) -> Status { + return for_body_generator(indvar); + }); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, int64 step, const std::function& for_body_generator) { - For(name, start, end, ir_builder_->getInt64(step), - /*peel_first_iteration=*/false, - [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + ForReturnVoid(name, start, end, ir_builder_->getInt64(step), + for_body_generator); + } + + Status For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + return For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } - void For( + void ForReturnVoid( tensorflow::StringPiece name, int64 start, int64 end, int64 step, const std::function& for_body_generator) { - For(name, /*start=*/ir_builder_->getInt64(start), - /*end=*/ir_builder_->getInt64(end), - /*step=*/ir_builder_->getInt64(step), for_body_generator); + ForReturnVoid(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); } // Generates the following control flow structure: @@ -123,9 +202,25 @@ class KernelSupportLibrary { // `true_block_generator()`; // else // `false_block_generator()`; - void If(llvm::Value* condition, - const std::function& true_block_generator, - const std::function& false_block_generator = []() {}); + Status If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = + []() -> Status { return Status::OK(); }); + + void IfReturnVoid(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() { + }) { + TF_CHECK_OK(If(condition, + [&]() { + true_block_generator(); + return Status::OK(); + }, + [&]() { + false_block_generator(); + return Status::OK(); + })); + } using ArgumentVector = tensorflow::gtl::ArraySlice; @@ -183,7 +278,7 @@ class KernelSupportLibrary { private: llvm::IRBuilder<>* ir_builder_; - bool prevent_unrolling_; + llvm_ir::UnrollMode unroll_mode_; bool prevent_vectorization_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 497b48ff227d7d1f158080529372df44b6932b24..9f867014fb015845448c4fcf9c165750f8a61935 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,7 +34,7 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling, + llvm::Value* step, UnrollMode unroll_mode, bool prevent_vectorization) : prefix_(std::string(prefix)), suffix_(std::string(suffix)), @@ -42,15 +42,15 @@ ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling), + unroll_mode_(unroll_mode), prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling, bool prevent_vectorization) { + UnrollMode unroll_mode, bool prevent_vectorization) { std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, - end_index, step, prevent_unrolling, + end_index, step, unroll_mode, prevent_vectorization)); loop->Emit(ir_builder); return loop; @@ -147,11 +147,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { std::vector ForLoop::GetLoopMetadata( llvm::IRBuilder<>* ir_builder) { const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full"; const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; llvm::LLVMContext* ctx = &start_index_->getContext(); std::vector result; - if (prevent_unrolling_) { + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) { result.push_back(llvm::MDNode::get( *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); } @@ -162,6 +163,10 @@ std::vector ForLoop::GetLoopMetadata( llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); } + if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)})); + } return result; } @@ -178,25 +183,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling, prevent_vectorization); + unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( - /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling, prevent_vectorization)); + /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode, + prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -215,23 +220,23 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling, + ir_builder_->getInt64(end_index), unroll_mode, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling, + UnrollMode unroll_mode, bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling, + ir_builder_->getInt64(stride), unroll_mode, prevent_vectorization); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index d915f95db134918a173a9711936bb1e2f1ea0d95..4e403cd994874c27453574283c5c573c876628db 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -34,6 +34,12 @@ limitations under the License. namespace xla { namespace llvm_ir { +enum class UnrollMode { + kDefaultUnroll, + kFullyUnroll, + kNoUnroll, +}; + // A class for constructing a for-loop in LLVM IR. class ForLoop { public: @@ -69,12 +75,13 @@ class ForLoop { // LLVM IR. If non-empty, it is prepended to the name of the induction // variable value and each basic block created for the loop. // - // If `prevent_unrolling` is true then emit metadata that directs LLVM to not - // unroll the generated loop. + // `unroll_mode` specifies the desired LLVM unrolling behavior for generated + // loop. static std::unique_ptr EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false, bool prevent_vectorization = false); + UnrollMode unroll_mode = llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -128,7 +135,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling, bool prevent_vectorization); + UnrollMode unroll_mode, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* ir_builder); @@ -161,7 +168,7 @@ class ForLoop { llvm::BasicBlock* body_bb_; llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; - bool prevent_unrolling_; + UnrollMode unroll_mode_; bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); @@ -182,34 +189,34 @@ class ForLoopNest { // Adds a loop to the nest. If no loop has been added yet then emit a loop at // the current insert point of the given builder. If one or more loops have - // been added then emit loop inside the body of the last added loop. If - // prevent_unrolling is true, then metadata is emitting directing LLVM to not - // unroll this loop. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + // been added then emit loop inside the body of the last added loop. + // unroll_mode is used to emit metadata that controls LLVM unrolling. + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* stride, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(tensorflow::StringPiece suffix, - llvm::Value* start_index, - llvm::Value* end_index, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + tensorflow::StringPiece suffix, llvm::Value* start_index, + llvm::Value* end_index, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, int64 stride, + tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. - std::unique_ptr AddLoop(int64 start_index, int64 end_index, - tensorflow::StringPiece suffix, - bool prevent_unrolling = false, - bool prevent_vectorization = false); + std::unique_ptr AddLoop( + int64 start_index, int64 end_index, tensorflow::StringPiece suffix, + UnrollMode unroll_mode = xla::llvm_ir::UnrollMode::kDefaultUnroll, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ec04239b4f9112134ba876fdfbb3905a3baf1f72..ff64da87e9c9acf8a9d7ff87d3b1be7a9e9106bb 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -87,18 +87,10 @@ llvm::Value* EmitCallToIntrinsic( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice overloaded_types, llvm::IRBuilder<>* ir_builder) { - std::vector types; - for (auto type : overloaded_types) { - types.push_back(type); - } llvm::Module* module = ModuleFromIRBuilder(ir_builder); - llvm::Function* intrinsic = - llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); - std::vector operands_vec; - for (auto operand : operands) { - operands_vec.push_back(operand); - } - return ir_builder->CreateCall(intrinsic, operands_vec); + llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + module, intrinsic_id, AsArrayRef(overloaded_types)); + return ir_builder->CreateCall(intrinsic, AsArrayRef(operands)); } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -368,15 +360,52 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, return llvm::ConstantArray::get(aggregate_type, elements); } +template +llvm::Constant* GetConstantDataArray(const Literal& literal, + llvm::Module* module) { + const T* data = static_cast(literal.untyped_data()); + int64 num_elements = literal.size_bytes() / sizeof(T); + return llvm::ConstantDataArray::get(module->getContext(), + llvm::makeArrayRef(data, num_elements)); +} + } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, llvm::Module* module) { - std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); - llvm::Constant* value = LiteralToConstant( - literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, module); - return value; + const Shape& shape = literal.shape(); + // TODO(b/29904935): We can get rid of this switch by exposing a + // ConstantDataArray factory method that takes a llvm::Type and a StringRef. + switch (shape.element_type()) { + case U64: + return GetConstantDataArray(literal, module); + case U32: + return GetConstantDataArray(literal, module); + case U8: + return GetConstantDataArray(literal, module); + case S64: + return GetConstantDataArray(literal, module); + case S32: + return GetConstantDataArray(literal, module); + case F64: + return GetConstantDataArray(literal, module); + case F32: + return GetConstantDataArray(literal, module); + case BF16: + case F16: + return GetConstantDataArray(literal, module); + case PRED: + return GetConstantDataArray(literal, module); + // TODO(b/29904935): Also use ConstantDataArray for complex numbers. + case C64: { + int64 dimensions = ShapeUtil::Rank(shape); + std::vector multi_index(dimensions, 0); + return LiteralToConstant(literal, /*dimension_index=*/dimensions - 1, + &multi_index, module); + } + default: + LOG(FATAL) << "unsupported type " << shape.element_type(); + } } llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type, diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 0728ccfff7b85e3751f33bc5272a5f22d4e5411a..dc2934a34c23f8229947210cacc9863d47c2ea55 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -83,7 +83,9 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, // Sanity check: In multi-output fusion, all shapes produced must have the // same dimensions. for (const IrArray& array : target_arrays) { - CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())) + << ": '" << shape_.ShortDebugString() << "' does not match '" + << array.GetShape().ShortDebugString() << "'"; } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc index 3a21eda35757aa706565ee4a5286eee1acea117b..5fc08aab916e377b245b6221108956c06da70767 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -24,15 +24,14 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module) { +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) { CHECK(ShapeUtil::IsScalar(pred.GetShape())); llvm::LoadInst* pred_value = @@ -47,30 +46,27 @@ void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; + llvm::Value* const element_index[] = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; llvm::Value* on_true_element_address = ir_builder->CreateInBoundsGEP(on_true, element_index); llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + on_true_element_address, "on_true_element_" + llvm::Twine(i)); llvm::Value* on_false_element_address = ir_builder->CreateInBoundsGEP(on_false, element_index); llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + on_false_element_address, "on_false_element_" + llvm::Twine(i)); llvm::Value* output_element_address = ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + ir_builder->CreateSelect(pred_cond, on_true_element, on_false_element, + "select_output_element_" + llvm::Twine(i)), output_element_address); } } -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module) { for (size_t i = 0; i < operands.size(); ++i) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h index dbf9a140068b60505f6798360438f709bfd3feba..352d34ebf839c6c2465abade7c3d3eb3b7a34506 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -59,13 +59,13 @@ namespace llvm_ir { // of the address from the corresponding element in either // tuple_on_true or tuple_on_false: // output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, - llvm::Module* module); +void EmitTupleSelect(const IrArray& select, const IrArray& pred, + llvm::Value* on_true, llvm::Value* on_false, + llvm::IRBuilder<>* ir_builder, llvm::Module* module); // A tuple is an array of pointers, one for each operand. Each pointer points to // the output buffer of its corresponding operand. -void EmitTuple(IrArray tuple, +void EmitTuple(const IrArray& tuple, tensorflow::gtl::ArraySlice operands, llvm::IRBuilder<>* ir_builder, llvm::Module* module); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 0fa4061738612df76c72a18a9353f16bf6a42677..296d04d4362b12fdc39798a016ca9e8795e02586 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -24,15 +24,12 @@ limitations under the License. #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -110,6 +107,11 @@ ExecutionOptions CreateExecutionOptions( ->set_xla_dump_optimized_hlo_proto_to( build_options.dump_optimized_hlo_proto_to().value()); } + if (build_options.dump_unoptimized_hlo_proto_to().has_value()) { + execution_options.mutable_debug_options() + ->set_xla_dump_unoptimized_hlo_proto_to( + build_options.dump_unoptimized_hlo_proto_to().value()); + } if (build_options.dump_per_pass_hlo_proto_to().has_value()) { execution_options.mutable_debug_options() ->set_xla_dump_per_pass_hlo_proto_to( @@ -124,75 +126,17 @@ ExecutionOptions CreateExecutionOptions( LayoutUtil::SetToDefaultLayout( execution_options.mutable_shape_with_output_layout()); } - return execution_options; -} - -} // namespace - -StatusOr> LocalService::CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& build_options) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(computation)); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Validate incoming layouts. - if (argument_layouts.size() != program_shape->parameters_size()) { - return InvalidArgument( - "Invalid number of arguments for computation: expected %d, got %zu.", - program_shape->parameters_size(), argument_layouts.size()); + for (const std::string& disabled_pass : build_options.disabled_hlo_passes()) { + execution_options.mutable_debug_options()->add_xla_disable_hlo_passes( + disabled_pass); } - for (int i = 0; i < argument_layouts.size(); ++i) { - const Shape& argument_shape = *argument_layouts[i]; - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); - if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { - tensorflow::gtl::optional metadata = - user_computation->ParameterMetadata(i); - auto metadata_string = [&metadata]() -> string { - if (!metadata.has_value()) { - return ""; - } - CHECK(metadata.value() != nullptr); - const OpMetadata& m = *metadata.value(); - if (!m.source_file().empty()) { - return tensorflow::strings::Printf( - " (%s:%d)", m.source_file().c_str(), m.source_line()); - } - return ""; - }; - return InvalidArgument( - "Invalid argument shape for argument %d%s, expected %s, got %s.", i, - metadata_string().c_str(), - ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), - ShapeUtil::HumanString(argument_shape).c_str()); - } - } - if (build_options.result_layout() != nullptr) { - TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( - *build_options.result_layout(), program_shape->result())); - } - - ExecutionOptions execution_options = - CreateExecutionOptions(build_options, program_shape.get()); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, - &execution_options, user_computation)); - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - execute_backend_->stream_executor(build_options.device_ordinal())); - - return BuildExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), executor, - build_options.device_allocator()); + return execution_options; } +} // namespace + StatusOr> LocalService::CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -260,4 +204,15 @@ StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { /*computation_count=*/1); } +StatusOr LocalService::GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number) { + TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data)); + if (replica_number >= buffers.size()) { + return InvalidArgument( + "replica_number %d out of range; must be less than num_replicas = %zu.", + replica_number, buffers.size()); + } + return buffers[replica_number]; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 06567cabd6eb28aae53881613cd6beb78e25e222..39d6734c3fc06df6832cf67edddbc7c14c815cd1 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -41,23 +41,11 @@ class LocalService : public Service { static StatusOr> NewService( const ServiceOptions& options); - // Builds an Executable with the given argument layouts and options. If - // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. If device_allocator is non-null, then the - // compiler may use it to allocate temp space on the device. The compiler is - // responsible for freeing any memory it allocates this way. - StatusOr> CompileExecutable( - const ComputationHandle& computation, - const tensorflow::gtl::ArraySlice argument_layouts, - const ExecutableBuildOptions& options); - // Builds an Executable with the given XlaComputation, argument layouts and // options. If result_layout is non-null, then the executable is compiled to // produce a result of the given layout. If device_allocator is non-null, // then the compiler may use it to allocate temp space on the device. The // compiler is responsible for freeing any memory it allocates this way. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> CompileExecutable( const XlaComputation& computation, const tensorflow::gtl::ArraySlice argument_layouts, @@ -70,6 +58,11 @@ class LocalService : public Service { // the "easy" case where a single replica is a single device. StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Converts a GlobalDataHandle into a pointer to a ShapedBuffer that's valid + // as long as the handle is valid. + StatusOr GlobalDataToShapedBuffer( + const GlobalDataHandle& data, int replica_number); + private: explicit LocalService(const ServiceOptions& options, std::unique_ptr backend); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 6aca6ba38572c5311797fbb91acbbcd6610a3410..f410921b4b5337192bdeae5924631d9c06b7d5a5 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -125,6 +125,12 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand. + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { // RecvDone doesn't create a new buffer but rather aliases its input (Recv) // tuple element at {0} to its output. diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index f4c63dd86b4d8a6f598d46047012e4e5bc7b3d7e..b5ef3967875a58b35631d5f69c210f5cbcd91250 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -59,6 +59,7 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc new file mode 100644 index 0000000000000000000000000000000000000000..29f787b86b9cbb6f80d048b46b78bdad8074f488 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -0,0 +1,342 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/multi_output_fusion.h" + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr MultiOutputFusion::Run(HloModule* module) { + bool changed = false; + + for (auto* computation : module->MakeNonfusionComputations()) { + computation_ = computation; + reachability_ = computation_->ComputeReachability(); + candidates_.clear(); + candidates_index_.clear(); + all_fusion_candidates_.clear(); + + int64 index = 0; + for (auto it : computation_->MakeInstructionPostOrder()) { + candidates_.emplace_back(it); + InsertOrDie(&candidates_index_, it, index++); + } + + // Create the initial candidate list for each Node. + for (auto& node : candidates_) { + HloInstruction* instruction = node.hlo; + int64 instruction_id = get_candidate_id(instruction); + FusionCandidate& instr_node = candidates_[instruction_id]; + if (!IsFusible(instruction)) { + continue; + } + all_fusion_candidates_.push_back(instruction); + + std::vector candidates; + tensorflow::gtl::FlatSet candidates_set; + VLOG(10) << "Looking at instruction: " << instruction->name(); + for (auto operand : instruction->operands()) { + // Filter out the non-interesting instructions -- they + // will not generate the savings. + if (!IsProfitableOperand(operand)) { + VLOG(10) << "Operand not profitable: " << operand->name(); + continue; + } + VLOG(10) << "Operand profitable: " << operand->name(); + for (auto user : operand->users()) { + VLOG(10) << "User: " << user->name(); + if (user == instruction || !IsFusible(user)) { + VLOG(10) << "User is not fusible, or is the instruction itself: " + << user->name(); + continue; + } + int64 user_id = get_candidate_id(user); + if (is_connected(instruction, user)) { + VLOG(10) << "User is connected: " << user->name(); + continue; + } + if (instruction_id < user_id && + user->opcode() == HloOpcode::kFusion) { + VLOG(10) << "User ID for user: " << user->name() << " is " + << user_id << " which is higher than " << instruction_id; + continue; + } + if (!LegalToFuse(instruction, user)) { + VLOG(10) << "User not legal to fuse: " << user->name(); + continue; + } + if (candidates_set.insert(user).second) { + VLOG(10) << "User added to candidate list: " << user->name(); + candidates.push_back(user); + } + } + } + + // Iterate over candidates rather than candidates_set to avoid + // nondeterminism. + for (auto candidate : candidates) { + int64 profit = GetProfit(instruction, candidate); + if (profit > 0) { + FusionCandidate& candidate_node = + candidates_[get_candidate_id(candidate)]; + instr_node.fusibles.emplace_back(candidate, profit); + candidate_node.fusibles.emplace_back(instruction, profit); + worklist_.emplace(instruction, candidate, profit); + } + } + } + if (Perform()) { + changed = true; + } + } + return changed; +} + +HloInstruction* MultiOutputFusion::Fuse(HloInstruction* instr1, + HloInstruction* instr2) { + HloInstruction* remaining = instr1; + HloInstruction* fused = instr2; + // Make sure that if only one of the instructions is a fusion, or if only one + // of the instructions is a multi-output fusion, it's what will be fused into. + // + // An invariant is that no bitcast nodes will show up in the middle of a + // fusion node. This invariant must hold in order for us to lower it. Given + // that, we require that during multi-output fusion, a fusion node ending with + // bitcast to preserve its structure as a nested fusion instead being + // merged and flattened. + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + std::swap(remaining, fused); + } + if (fused->IsMultiOutputFusion()) { + std::swap(remaining, fused); + } + + if (fused->opcode() == HloOpcode::kFusion && + fused->fused_expression_root()->opcode() != HloOpcode::kBitcast) { + remaining->MergeFusionInstructionIntoMultiOutput(fused); + } else { + if (remaining->opcode() == HloOpcode::kFusion && + remaining->fused_expression_root()->opcode() == HloOpcode::kBitcast) { + auto parent_computation = remaining->parent(); + // Create a nested fusion node. + auto remaining_nested_fused = + parent_computation->AddInstruction(HloInstruction::CreateFusion( + remaining->shape(), HloInstruction::FusionKind::kLoop, + remaining)); + TF_CHECK_OK(parent_computation->ReplaceInstruction( + remaining, remaining_nested_fused)); + remaining = remaining_nested_fused; + } + remaining->FuseInstructionIntoMultiOutput(fused); + } + + return remaining; +} + +void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) { + HloInstruction* fusion = instr1; + HloInstruction* fused = instr2; + if (is_fused(instr1)) { + fusion = instr2; + fused = instr1; + } + + // Insert the newly created instruction (if any), to candidates_. + for (auto use : fusion->users()) { + if (candidates_index_.find(use) == candidates_index_.end()) { + int64 index = candidates_.size(); + candidates_.emplace_back(use); + InsertOrDie(&candidates_index_, use, index++); + } + } + FusionCandidate& fusion_node = candidates_[get_candidate_id(fusion)]; + FusionCandidate& fused_node = candidates_[get_candidate_id(fused)]; + + // Update the reachability graph. + UpdateReachability(fusion, fused, all_fusion_candidates_, + [this](HloInstruction* instr) { return is_fused(instr); }); + + // Update the fusible list for fusion. Variable new_fusibles keeps + // track of the new or changed entries. + std::vector> new_fusibles; + tensorflow::gtl::FlatSet in_list; + auto it = fusion_node.fusibles.begin(); + while (it != fusion_node.fusibles.end()) { + HloInstruction* instr = it->first; + if (is_fused(instr) || is_connected(fusion, instr)) { + it = fusion_node.fusibles.erase(it); + continue; + } + in_list.insert(instr); + int64 profit = GetProfit(instr, fusion); + if (profit > it->second) { + it->second = profit; + new_fusibles.emplace_back(instr, profit); + } + ++it; + } + + // Fused_node has been fused into fusion_node. Take the fusion candidates + // (fusibles) from fused_nodes and add them to the fusion_node's. Filter + // out those fusibles that no longer valid (or already in the list). + for (const auto& it : fused_node.fusibles) { + HloInstruction* instr = it.first; + if (instr == fusion || is_fused(instr) || is_connected(fusion, instr)) { + continue; + } + if (in_list.count(instr) > 0) { + continue; + } + int64 profit = GetProfit(instr, fusion); + fusion_node.fusibles.emplace_back(instr, profit); + new_fusibles.emplace_back(instr, profit); + } + fused_node.fusibles.clear(); + + // Update the worklist_. + for (auto it : new_fusibles) { + worklist_.emplace(fusion, it.first, it.second); + } +} + +bool MultiOutputFusion::LegalToFuse(HloInstruction* instr1, + HloInstruction* instr2) { + if (instr1 == instr2) { + return false; + } + if (instr1->opcode() != HloOpcode::kFusion) { + return false; + } + + // Fusing nodes with 0 user makes no sense and the rest of the implementation + // doesn't support it either. + if (instr1->user_count() == 0 || instr2->user_count() == 0) { + return false; + } + + // Check if the users of multioutput fusion is not a get-tuple-element. + // If this is the case, we bail out because the transformation assumes + // the users are get-tuple-element. + auto multioutput_user_is_not_gte = [](HloInstruction* instr) { + if (!instr->IsMultiOutputFusion()) { + return false; + } + for (auto user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + return true; + } + } + return false; + }; + if (multioutput_user_is_not_gte(instr1) || + multioutput_user_is_not_gte(instr2)) { + return false; + } + + if (is_connected(instr1, instr2)) { + return false; + } + if (!ShapesCompatibleForFusion(instr1, instr2)) { + return false; + } + + return true; +} + +void MultiOutputFusion::UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip) { + for (auto instr : instrs_to_update) { + if (skip != nullptr && skip(instr)) { + continue; + } + if (reachability_->IsReachable(instr2, instr) && + reachability_->IsReachable(instr1, instr)) { + // If a candidate was already reachable by both, no update needed. + continue; + } + if (reachability_->IsReachable(instr2, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr1}, instr); + } + if (reachability_->IsReachable(instr1, instr)) { + reachability_->FastSetReachabilityToUnion({instr, instr2}, instr); + } + } +} + +bool MultiOutputFusion::Perform() { + int changed = false; + // Pick the top candidate from queue and try to merge. + while (!worklist_.empty()) { + if (fuel_ <= 0) { + VLOG(2) << "No fusing: run out of fuel."; + break; + } + ToBeFused candidate = worklist_.top(); + worklist_.pop(); + + HloInstruction* instr1 = candidate.instr1; + HloInstruction* instr2 = candidate.instr2; + + if (is_fused(instr1) || is_fused(instr2)) { + continue; + } + + VLOG(1) << "Considering candidate profit_score=" << candidate.score + << "\n\t\tinstr1 = " << instr1->ToString() + << "\n\t\tinstr2 = " << instr2->ToString(); + + if (LegalToFuse(instr1, instr2)) { + VLOG(1) << "Fuse!"; + VLOG(2) << "Before multi_output_fusion:"; + VLOG(2) << "instr1: " << instr1->ToString(); + VLOG(2) << "\n" + << instr1->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + VLOG(2) << "instr2: " << instr2->ToString(); + if (instr2->opcode() == HloOpcode::kFusion) { + VLOG(2) << "\n" + << instr2->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + } + HloInstruction* ret = Fuse(instr1, instr2); + set_is_fused(ret == instr1 ? instr2 : instr1); + Update(instr1, instr2); + changed = true; + VLOG(2) << "After fusion, \t this: " << ret->name() << "\n" + << ret->fused_instructions_computation()->ToString( + HloPrintOptions().set_indent_amount(1)); + auto users = ret->users(); + --fuel_; + } + } + if (DoProducerConsumerMultiOutputFusion(computation_)) { + changed = true; + } + return changed; +} + +bool MultiOutputFusion::DoProducerConsumerMultiOutputFusion( + HloComputation* /*computation*/) { + return false; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h new file mode 100644 index 0000000000000000000000000000000000000000..cfdf83cfe856a7c3b05f51129446cd4e1055a8d6 --- /dev/null +++ b/tensorflow/compiler/xla/service/multi_output_fusion.h @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { + +// This class implements the fusing of sibling fusion instructions that sharing +// common operands. +// It constructs the following associated data structures. +// (1) candidates_: stores the instruction and the set of instructions it can +// fuse to. +// (2) candidates_index_: maps instruction to id. +// (3) reachability_: reachability map in this computation. +// (4) all_fusion_candidates_: the vector of candidate instructions. +// (5) worklist_: a priority queue that contains pairs of instructions to be +// fused and their fusion profit scores. +// +// Function Perform() applies the optimization. It picks up the most profitable +// pair in the worklist_, check if it's legal to fuse and fuse the pair. +// After fusion, it updates the associated structure such as reachability_, +// candidates_ and worklist_. +// Note that the reachability map is updated based on the original computation. +// This works because the reachability is monotonically increasing with +// instruction fusion. +class MultiOutputFusion : public HloPassInterface { + public: + MultiOutputFusion(int64 fuel) : fuel_(fuel) {} + + tensorflow::StringPiece name() const override { + return "multi_output_fusion"; + } + + // Run multi-output fusion on the given module. Returns whether the module + // was changed. + StatusOr Run(HloModule* module) override; + + protected: + // Main entry for the optimization. Returns true if the optimization happens. + bool Perform(); + + // Test if instr1 and instr2 have the compatible shapes that can be legally + // fused. + virtual bool ShapesCompatibleForFusion(HloInstruction* instr1, + HloInstruction* instr2) = 0; + + // Whether the instruction is a candidate for fusion. + virtual bool IsFusible(HloInstruction* instr) = 0; + + // This function estimates the savings by merging instr1 and instr2 into one + // multi-output fusion instruction. + virtual int64 GetProfit(HloInstruction* instr1, HloInstruction* instr2) = 0; + + // Whether fusing the instruction can reduce cost. + virtual bool IsProfitableOperand(HloInstruction* instr) = 0; + + // Test if it's legal to fuse instr1 and instr2 into one fusion instruction. + virtual bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the reachability map after fusing instr1 and instr2. + void UpdateReachability( + HloInstruction* instr1, HloInstruction* instr2, + tensorflow::gtl::ArraySlice instrs_to_update, + const std::function& skip = nullptr); + + // Hook for multi-output fusion along producer-consumer edges. + // Returns whether any instructions were fused. + // + // TODO(b/80420762): Perform producer-consumer multi-output fusion in + // InstructionFusion instead. + virtual bool DoProducerConsumerMultiOutputFusion(HloComputation* computation); + + private: + // Fuse HloInstrctuion instr1 and instr2 and return the fused instruction. + // The other instruction is removed from its parent computation. + HloInstruction* Fuse(HloInstruction* instr1, HloInstruction* instr2); + + // Update the internal data structures after instr1 and instr2 are fused into + // one fusion instruction. + void Update(HloInstruction* instr1, HloInstruction* instr2); + + // Optimization fuel is a compiler debugging technique that makes an + // optimization pass stop what it is doing after having made N changes to the + // program, where N is the fuel. By varying N, this can be used to find the + // first single change that makes a test fail. + int64 fuel_; + + // Computation for the pass. + HloComputation* computation_; + + // An internal data structure for each instruction in current computation. + // When an instruction is removed, member 'hlo' is set to nullptr. + struct FusionCandidate { + HloInstruction* hlo; + std::list> fusibles; + explicit FusionCandidate(HloInstruction* hlo) : hlo(hlo) {} + }; + std::vector candidates_; + + // A map that maps an instruction to the index_. + tensorflow::gtl::FlatMap candidates_index_; + + // The reachability map of current computation. + std::unique_ptr reachability_; + + // This stores all the candidate instructions in current computation. + std::vector all_fusion_candidates_; + + // The pair of candidates to be fused and the profit score. + struct ToBeFused { + HloInstruction* instr1; + HloInstruction* instr2; + int64 score; + ToBeFused(HloInstruction* instr1, HloInstruction* instr2, int64 score) + : instr1(instr1), instr2(instr2), score(score) {} + bool operator<(const ToBeFused& rhs) const { return score < rhs.score; } + }; + std::priority_queue worklist_; + + int64 get_candidate_id(HloInstruction* instr) { + return FindOrDie(candidates_index_, instr); + } + + bool is_fused(HloInstruction* instr) { + return candidates_[get_candidate_id(instr)].hlo == nullptr; + } + + void set_is_fused(HloInstruction* instr) { + candidates_[get_candidate_id(instr)].hlo = nullptr; + } + + bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { + return reachability_->IsConnected(instr1, instr2); + } +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MULTI_OUTPUT_FUSION_H_ diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index d3bc47e61e0e75fa2ef181988700f88cec9c1d76..2515222cf2db3d9699c85c13f4fe72b3488fa217 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -204,7 +204,7 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. constexpr LayoutPattern> EqualTo( - const Layout* layout) const { + const ::xla::Layout* layout) const { return LayoutPattern>( LayoutPatternEqualImpl(impl_, layout), matched_layout_); } diff --git a/tensorflow/compiler/xla/service/pattern_matcher_test.cc b/tensorflow/compiler/xla/service/pattern_matcher_test.cc index 204e8c99209fa95adb868a676bb9e5144fed432c..fef3c132b0f3467a01b02f2be88b419459179277 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher_test.cc +++ b/tensorflow/compiler/xla/service/pattern_matcher_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/pattern_matcher.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/core/platform/test.h" namespace xla { @@ -29,7 +29,7 @@ TEST(PatternMatcherTest, AddOp) { ROOT %two_plus_two = f32[] add(f32[] %two, f32[] %two) } )"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); const HloInstruction* matched_inst; HloInstruction* matched_operand; @@ -182,7 +182,7 @@ TEST(PatternMatcherTest, FusionKind) { p0 = f32[] parameter(0) ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=fused_computation })"; - TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, tools::Parse(kModuleStr)); + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloString(kModuleStr)); auto* root = hlo_module->entry_computation()->root_instruction(); EXPECT_TRUE(Match( diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 0f26a025bf125f70199637894741540f89eae7e5..49ec38eb62c7b51c7a2d301d882cef032b288036 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -155,20 +155,15 @@ HloInstruction* UpdateOperand(const HloInstruction* first_reshape_operand, case HloOpcode::kConstant: { if (first_reshape_operand->opcode() == HloOpcode::kReshape) { VLOG(5) << "Adding reshape to kConstant operand"; - HloInstruction* reshape = computation->AddInstruction( + return computation->AddInstruction( HloInstruction::CreateReshape(new_shape, operand)); - operand->SetupDerivedInstruction(reshape); - return reshape; } else { CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); VLOG(5) << "Adding transpose to kConstant operand"; std::vector inverse_permutation = InversePermutation(first_reshape_operand->dimensions()); - HloInstruction* transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - new_shape, operand, inverse_permutation)); - operand->SetupDerivedInstruction(transpose); - return transpose; + return computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); } } case HloOpcode::kRng: { diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index cb0f76ebe4d445059fdf37ebf559bef851a57104..d01c35b99231310692f85d0f9fbf4f2c3709d44c 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" -#include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" @@ -62,33 +61,6 @@ namespace xla { namespace { -// Records the arguments used to invoke a computation in a SessionModule -// proto. -Status RecordArguments( - const tensorflow::gtl::ArraySlice arguments, - se::StreamExecutor* executor, TransferManager* transfer_manager, - SessionModule* module) { - module->clear_arguments(); - for (const ShapedBuffer* argument : arguments) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, *argument)); - *module->add_arguments() = literal->ToProto(); - } - return Status::OK(); -} - -// Records the result of a computation in a SessionModule proto. -Status RecordResult(const ShapedBuffer& result, se::StreamExecutor* executor, - TransferManager* transfer_manager, SessionModule* module) { - module->clear_result(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr literal, - transfer_manager->TransferLiteralFromDevice(executor, result)); - *module->mutable_result() = literal->ToProto(); - return Status::OK(); -} - // Records the arguments used to invoke a computation in an HloSnapshot proto. Status RecordArguments( const tensorflow::gtl::ArraySlice arguments, @@ -195,20 +167,6 @@ Service::Service(const ServiceOptions& options, } } -Status Service::Computation(const ComputationRequest* arg, - ComputationResponse* result) { - if (arg->name().empty()) { - return InvalidArgument("computation request needs a name"); - } - - *result->mutable_computation() = - computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p, name %s", - result->computation().ShortDebugString().c_str(), this, - arg->name().c_str()); - return Status::OK(); -} - Status Service::CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) { *result->mutable_channel() = channel_tracker_.NewChannel(); @@ -288,8 +246,7 @@ Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); ComputationLayout* host_computation_layout = config->mutable_host_entry_computation_layout(); @@ -305,17 +262,9 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - if (user_computation == nullptr) { - return InvalidArgument( - "Argument does not match shape of computation parameter %d: want " - "%s, got %s", - i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), - ShapeUtil::HumanString(*argument_shapes[i]).c_str()); - } - return InvalidParameterArgument( - *user_computation->ParameterMetadata(i).value(), - "Argument does not match shape of computation parameter %d: want %s, " - "got %s", + return InvalidArgument( + "Argument does not match shape of computation parameter %d: want " + "%s, got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } @@ -366,76 +315,12 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation) { + const ExecutionOptions& execution_options) { std::vector argument_shapes; for (const auto* arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options, - user_computation); -} - -StatusOr>> Service::BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p", this); - - // Dump computation proto state if flag is set. - std::vector> session_modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const string& directory_path = - module_configs[i]->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_configs[i]->debug_options().xla_dump_executions_to(); - if (directory_path.empty() && other_directory_path.empty()) { - continue; - } - TF_ASSIGN_OR_RETURN( - std::unique_ptr session_module, - computation_tracker_.SnapshotComputation(versioned_handles[i].handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handles[i].handle.handle(), - session_module->entry().name().c_str(), - versioned_handles[i].version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - session_modules.push_back(std::move(session_module)); - } - } - - VLOG(1) << "Computation handles:"; - for (const VersionedComputationHandle& versioned_handle : versioned_handles) { - VLOG(1) << versioned_handle; - } - - CHECK_EQ(versioned_handles.size(), module_configs.size()); - std::vector> modules; - for (int64 i = 0; i < versioned_handles.size(); ++i) { - const VersionedComputationHandle& versioned_handle = versioned_handles[i]; - const HloModuleConfig& config = *module_configs[i]; - TF_ASSIGN_OR_RETURN(auto module, - computation_tracker_.BuildHloModule( - versioned_handle, config, - /*include_unreachable_instructions=*/true)); - modules.push_back(std::move(module)); - } - - TF_ASSIGN_OR_RETURN( - std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors), - device_allocator)); - - for (size_t i = 0; i < versioned_handles.size(); ++i) { - if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { - executables[i]->set_session_module(std::move(session_modules[i])); - } - } - - return std::move(executables); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options); } StatusOr>> Service::BuildExecutables( @@ -512,98 +397,6 @@ Status Service::ValidateEntryComputationLayout(HloModule* module) { return Status::OK(); } -StatusOr> Service::BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { - VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, - versioned_handle.ToString().c_str()); - - // Dump computation proto state if flag is set. - std::unique_ptr session_module; - const string& directory_path = - module_config->debug_options().xla_dump_computations_to(); - const string& other_directory_path = - module_config->debug_options().xla_dump_executions_to(); - if (!directory_path.empty() || !other_directory_path.empty()) { - TF_ASSIGN_OR_RETURN( - session_module, - computation_tracker_.SnapshotComputation(versioned_handle.handle)); - if (!directory_path.empty()) { - string filename = Printf("computation_%lld__%s__version_%lld", - versioned_handle.handle.handle(), - session_module->entry().name().c_str(), - versioned_handle.version); - TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename, - *session_module)); - } - } - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - true)); - - TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); - - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor, - device_allocator)); - // Check that on-host and on-device shapes are consistent. - TF_RETURN_IF_ERROR(ValidateEntryComputationLayout(module.get())); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - backend->compiler()->RunBackend( - std::move(module), executor, device_allocator)); - - if (!other_directory_path.empty()) { - executable->set_session_module(std::move(session_module)); - } - - return std::move(executable); -} - -StatusOr> Service::BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator) { - std::shared_ptr executable = - compilation_cache_.LookUp(versioned_handle, *module_config); - - if (executable != nullptr) { - // Executable found in the computation cache. - if (profile != nullptr) { - profile->set_compilation_cache_hit(true); - } - return executable; - } - - uint64 start_micros = - // Avoid reading the clock if we don't want timing info - (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0; - - // Take a copy of the module config, as compilation introduces layouts where - // layouts were optional before. - HloModuleConfig original_module_config = *module_config; - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), backend, - executor, device_allocator)); - - if (profile != nullptr) { - uint64 end_micros = tensorflow::Env::Default()->NowMicros(); - uint64 milliseconds = (end_micros - start_micros) / 1000; - profile->set_compilation_cache_hit(false); - profile->set_compile_time_ms(milliseconds); - } - - // Insert executable into the cache. - return compilation_cache_.Insert(std::move(executable_unique_ptr), - original_module_config); -} - StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, @@ -624,9 +417,16 @@ Service::ExecuteParallelAndRegisterResult( // profiled. std::map index_to_profiled_streams; - TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, - backend->computation_placer()->AssignDevices( - options_.number_of_replicas(), executables.size())); + // Build DeviceAssignment for all cores based on the provided device handles. + DeviceAssignment device_assignment(options_.number_of_replicas(), + executables.size()); + for (int64 i = 0; i < executables.size(); i++) { + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + CHECK_EQ(replicas.size(), arguments[i].size()); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + device_assignment(replica, i) = replicas[replica]->device_ordinal(); + } + } for (int64 i = 0; i < executables.size(); i++) { // Stream executors for the replicas of the current computation. @@ -799,13 +599,6 @@ StatusOr Service::ExecuteAndRegisterResult( result_tag); } -Status Service::SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - return computation->SetReturnValue(arg->operand()); -} - StatusOr> Service::GetExecutors( const ExecutionOptions& execution_options, int64 requests_size, int64 request_index) const { @@ -847,117 +640,6 @@ StatusOr>> Service::GetArguments( return replicated_arguments; } -Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) { - VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - - std::vector>> all_arguments; - std::vector> all_executors; - std::vector versioned_handles; - std::vector> module_configs; - std::vector computation_names; - std::vector device_handles; - - int num_requested_devices = - std::accumulate(arg->requests().begin(), arg->requests().end(), 0, - [](int a, const ExecuteRequest& r) -> int { - return a + r.execution_options().device_handles_size(); - }); - if (num_requested_devices * options_.number_of_replicas() > - execute_backend_->device_count()) { - return FailedPrecondition( - "there are not enough stream executors to execute %d computations", - num_requested_devices); - } - - for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor for the i'th computation. This stream executor - // is one of the executors to run the replicated computation. - const ExecutionOptions& execution_options = - arg->requests(i).execution_options(); - - // Get the executors. - TF_ASSIGN_OR_RETURN(auto executors, GetExecutors(execution_options, - arg->requests_size(), i)); - - // Resolve the UserComputation object associated with the requested - // computation and compute the program shape. - const ExecuteRequest& request = arg->requests(i); - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(request.computation())); - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - // Get the replicated arguments. - TF_ASSIGN_OR_RETURN(auto replicated_arguments, - GetArguments(execution_options, request.arguments())); - - // Create an HloModuleConfig object for the computation, given the shape of - // the program and the argument allocations. Here, we care only about the - // shapes of the arguments, so, it is sufficient to use the arguments of - // replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - request.execution_options(), user_computation)); - VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - // Adds to the vectors to build and execute the computations after the loop. - all_arguments.push_back(replicated_arguments); - all_arguments.insert(all_arguments.end(), executors.size() - 1, {{}}); - versioned_handles.push_back(versioned_handle); - module_configs.push_back(std::move(module_config)); - computation_names.insert(computation_names.end(), executors.size(), - user_computation->name()); - all_executors.push_back(executors); - device_handles.insert(device_handles.end(), - execution_options.device_handles().begin(), - execution_options.device_handles().end()); - } - - // Build the user computations into HloModules and compile to generate the - // executables. - // - // TODO(jlebar): There's currently no way to pass a device allocator to - // ExecuteParallel, so we have to pass a null device_allocator below. - TF_ASSIGN_OR_RETURN( - std::vector> executables, - BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors, - /*device_allocator=*/nullptr)); - std::vector executable_ptrs; - executable_ptrs.reserve(executables.size()); - for (const auto& executable : executables) { - executable_ptrs.push_back(executable.get()); - } - - // Execute the generated executables in parallel and return the device - // handles for each computation's output. - ExecutionProfile profile; - TF_ASSIGN_OR_RETURN( - std::vector outputs, - ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), device_handles, - computation_names, &profile)); - for (const GlobalDataHandle& output : outputs) { - ExecuteResponse response; - *response.mutable_output() = output; - *response.mutable_profile() = profile; - *result->add_responses() = response; - } - - VLOG(1) << "successfully completed 'execute-parallel' request"; - return Status::OK(); -} - Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-graph-parallel request"; @@ -1007,8 +689,7 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, std::unique_ptr module_config, CreateModuleConfig(request.computation().program_shape(), replicated_arguments.front(), - request.execution_options(), - /*user_computation=*/nullptr)); + request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->host_entry_computation_layout().ToString(); @@ -1083,15 +764,6 @@ Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, return Status::OK(); } -Status Service::ExecuteOneToN(const ExecuteRequest* arg, - ExecuteResponse* result) { - ExecuteParallelRequest parallel_arg; - *parallel_arg.add_requests() = *arg; - ExecuteParallelResponse parallel_result; - TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); - return PickParallelResponse(parallel_result, result); -} - Status Service::ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result) { ExecuteGraphParallelRequest parallel_arg; @@ -1124,80 +796,6 @@ Status Service::PickParallelResponse( return Status::OK(); } -Status Service::Execute(const ExecuteRequest* arg, ExecuteResponse* result) { - VLOG(1) << "running execute request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - // If we received multiple device handles, we must partition the module. - if (arg->execution_options().device_handles_size() > 1) { - return ExecuteOneToN(arg, result); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - // Since we care only about the shapes of the arguments, it is sufficient to - // use the arguments of replica 0. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "Execute created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - execute_backend_.get(), - execute_backend_->default_stream_executor(), - result->mutable_profile())); - - if (executable->dumping()) { - executable->session_module()->set_execution_platform( - execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR(RecordArguments( - replicated_arguments.front(), - execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - } - - TF_ASSIGN_OR_RETURN( - *result->mutable_output(), - ExecuteAndRegisterResult( - executable.get(), replicated_arguments, execute_backend_.get(), - "result of " + user_computation->name(), result->mutable_profile())); - - if (executable->dumping()) { - TF_ASSIGN_OR_RETURN( - const ShapedBuffer* result_buffer, - allocation_tracker_.ResolveForReplica(result->output(), 0)); - TF_RETURN_IF_ERROR(RecordResult( - *result_buffer, execute_backend_->default_stream_executor(), - execute_backend_->transfer_manager(), executable->session_module())); - TF_RETURN_IF_ERROR(executable->DumpSessionModule()); - } - - VLOG(1) << "successfully completed 'execute' request"; - return Status::OK(); -} - StatusOr> Service::BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -1303,86 +901,6 @@ Status Service::ExecuteGraph(const ExecuteGraphRequest* arg, return Status::OK(); } -Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) { - VLOG(1) << "running execute-async request: " << arg->ShortDebugString(); - - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - std::shared_ptr program_shape, - user_computation->ComputeProgramShape(versioned_handle.version)); - - TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, - SingleComputationDeviceHandle())); - TF_RET_CHECK(!replicas.empty()); - TF_ASSIGN_OR_RETURN( - std::vector> replicated_arguments, - ResolveAndValidateArguments(arg->arguments(), replicas)); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr module_config, - CreateModuleConfig(*program_shape, replicated_arguments.front(), - arg->execution_options(), user_computation)); - - VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " - << module_config->host_entry_computation_layout().ToString(); - - ExecutionProfile profile; - - TF_ASSIGN_OR_RETURN( - std::shared_ptr executable, - BuildAndCacheExecutable( - versioned_handle, std::move(module_config), execute_backend_.get(), - execute_backend_->default_stream_executor(), &profile)); - - // Set up streams. - std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : replicas) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - execute_backend_->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - std::vector result_buffers; - for (size_t i = 0; i < streams.size(); ++i) { - const auto& stream = streams[i]; - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(execute_backend_->memory_allocator()); - options.set_intra_op_thread_pool( - execute_backend_->eigen_intra_op_thread_pool_device()); - - ServiceExecutableRunOptions service_options( - options, execute_backend_->StreamBorrower()); - - TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer, - executable->ExecuteAsyncOnStream( - &service_options, replicated_arguments[i])); - - result_buffers.emplace_back(std::move(this_result_buffer)); - } - - TF_ASSIGN_OR_RETURN( - GlobalDataHandle output, - allocation_tracker_.RegisterReplicatedBuffers( - std::move(result_buffers), "result of " + user_computation->name())); - - *result->mutable_execution() = execution_tracker_.Register( - execute_backend_.get(), std::move(streams), profile, output); - streams.clear(); - - VLOG(1) << "successfully completed 'execute-async' request"; - return Status::OK(); -} - Status Service::WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) { TF_ASSIGN_OR_RETURN(const auto execution, @@ -1549,117 +1067,6 @@ Status Service::ResetDevice(const ResetDeviceRequest* arg, return execute_backend_->ResetDevices(); } -Status Service::IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->num_parameters())); - - result->set_is_constant(is_constant); - return Status::OK(); -} - -Status Service::ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandleAtOperation(arg->operand()); - - if (user_computation->request_count(versioned_handle.version) == 0) { - return InvalidArgument("computations may not be empty"); - } - - TF_ASSIGN_OR_RETURN( - bool is_constant, - user_computation->IsConstant(arg->operand(), arg->parameters_size())); - if (!is_constant) { - StatusOr op_request_status = - user_computation->LookUpRequestForErrorReporting(arg->operand()); - string op_request_string = ""; - if (op_request_status.ok()) { - op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); - } - return InvalidArgument( - "Operand to ComputeConstant depends on a parameter.\n\n" - " op requested for constant evaluation: %s\n\n" - "This is an internal error that typically happens when the XLA user " - "(e.g. TensorFlow) is attempting to determine a value that must be a " - "compile-time constant (e.g. an array dimension) but it is not capable " - "of being evaluated at XLA compile time.\n\n" - "Please file a usability bug with the framework being used (e.g. " - "TensorFlow).", - op_request_string.c_str()); - } - - // We can't use ComputeProgramShape because it checks that all parameter - // instructions are present and contiguous. Instead construct ProgramShape - // directly. - ProgramShape program_shape; - TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(), - user_computation->GetShape(arg->operand())); - - TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); - - ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions(); - execution_options.mutable_debug_options()->set_xla_enable_fast_math(false); - execution_options.mutable_debug_options() - ->set_xla_eliminate_hlo_implicit_broadcast(true); - *execution_options.mutable_shape_with_output_layout() = - program_shape.result(); - - Shape shape_with_output_layout(program_shape.result()); - if (arg->has_output_layout()) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( - arg->output_layout(), execution_options.shape_with_output_layout())); - *execution_options.mutable_shape_with_output_layout()->mutable_layout() = - arg->output_layout(); - } - - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options, - user_computation)); - - // Exclude dead parameter instructions for the purpose of computing constants. - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, *module_config, - /*include_unreachable_instructions=*/ - false)); - - std::vector> parameters(arg->parameters_size()); - for (int64 i = 0; i < arg->parameters_size(); ++i) { - TF_ASSIGN_OR_RETURN(parameters[i], - Literal::CreateFromProto(arg->parameters(i))); - } - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN( - auto result_literal, - evaluator.Evaluate>(*module, parameters)); - - // Since the shape_with_output_layout option in ExecutionOption is - // non-effective to the Evaluator results, explicit relayout here. - // - // TODO(b/77824332): Make HloEvaluator take care of the re-layout. - if (arg->has_output_layout()) { - result_literal = result_literal->Relayout(arg->output_layout()); - } - *result->mutable_literal() = result_literal->ToProto(); - - return Status::OK(); -} - Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) { if (!arg->has_computation()) { @@ -1709,60 +1116,6 @@ Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { return Status::OK(); } -Status Service::GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - computation->GetVersionedHandle(); - - TF_ASSIGN_OR_RETURN(auto program_shape, computation->ComputeProgramShape( - versioned_handle.version)); - *result->mutable_program_shape() = *program_shape; - return Status::OK(); -} - -Status Service::GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_shape(), - computation->GetShape(arg->operand())); - return Status::OK(); -} - -Status Service::GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * user_computation, - computation_tracker_.Resolve(arg->computation())); - - VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - - HloModuleConfig config; - config.set_debug_options(arg->debug_options()); - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.BuildHloModule(versioned_handle, config)); - - hlo_graph_dumper::MaybeDumpHloModule(*module, - "computation statistics subject"); - - // Run HLO analysis to get the computation statistics. - HloCostAnalysis analysis( - execute_backend_->compiler()->ShapeSizeBytesFunction()); - - TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis)); - - ComputationStats stats; - stats.set_flop_count(analysis.flop_count()); - stats.set_transcendental_count(analysis.transcendental_count()); - *result->mutable_stats() = stats; - return Status::OK(); -} - Status Service::GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) { if (!arg->has_computation()) { @@ -1793,262 +1146,6 @@ Status Service::GetComputationGraphStats( return Status::OK(); } -template -Status Service::AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - - TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation)); - return Status::OK(); -} - -Status Service::Op(const OpRequest* arg, OpResponse* result) { - TF_ASSIGN_OR_RETURN(UserComputation * computation, - computation_tracker_.Resolve(arg->computation())); - StatusOr handle_status; - - switch (arg->op_case()) { - case OpRequest::kBatchNormTrainingRequest: - handle_status = computation->AddBatchNormTrainingInstruction( - arg->batch_norm_training_request()); - break; - case OpRequest::kBatchNormInferenceRequest: - handle_status = computation->AddBatchNormInferenceInstruction( - arg->batch_norm_inference_request()); - break; - case OpRequest::kBatchNormGradRequest: - handle_status = computation->AddBatchNormGradInstruction( - arg->batch_norm_grad_request()); - break; - case OpRequest::kBinaryOpRequest: - handle_status = - computation->AddBinaryInstruction(arg->binary_op_request()); - break; - case OpRequest::kBroadcastRequest: - handle_status = - computation->AddBroadcastInstruction(arg->broadcast_request()); - break; - case OpRequest::kCallRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->call_request().to_apply())); - handle_status = - computation->AddCallInstruction(arg->call_request(), *to_apply); - break; - } - case OpRequest::kConcatenateRequest: - handle_status = - computation->AddConcatenateInstruction(arg->concatenate_request()); - break; - case OpRequest::kConditionalRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * true_computation, - computation_tracker_.Resolve( - arg->conditional_request().true_computation())); - TF_ASSIGN_OR_RETURN(UserComputation * false_computation, - computation_tracker_.Resolve( - arg->conditional_request().false_computation())); - handle_status = computation->AddConditionalInstruction( - arg->conditional_request(), *true_computation, *false_computation); - break; - } - case OpRequest::kConstantRequest: - handle_status = - computation->AddConstantInstruction(arg->constant_request()); - break; - case OpRequest::kConvertRequest: - handle_status = - computation->AddConvertInstruction(arg->convert_request()); - break; - case OpRequest::kBitcastConvertRequest: - handle_status = computation->AddBitcastConvertInstruction( - arg->bitcast_convert_request()); - break; - case OpRequest::kConvolveRequest: - handle_status = - computation->AddConvolveInstruction(arg->convolve_request()); - break; - case OpRequest::kCrossReplicaSumRequest: - handle_status = computation->AddCrossReplicaSumInstruction( - arg->cross_replica_sum_request()); - break; - case OpRequest::kCustomCallRequest: - handle_status = - computation->AddCustomCallInstruction(arg->custom_call_request()); - break; - case OpRequest::kDotRequest: - handle_status = computation->AddDotInstruction(arg->dot_request()); - break; - case OpRequest::kDynamicSliceRequest: - handle_status = - computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); - break; - case OpRequest::kDynamicUpdateSliceRequest: - handle_status = computation->AddDynamicUpdateSliceInstruction( - arg->dynamic_update_slice_request()); - break; - case OpRequest::kFftRequest: - handle_status = computation->AddFftInstruction(arg->fft_request()); - break; - case OpRequest::kGatherRequest: - handle_status = computation->AddGatherInstruction(arg->gather_request()); - break; - case OpRequest::kGetTupleElementRequest: - handle_status = computation->AddGetTupleElementInstruction( - arg->get_tuple_element_request()); - break; - case OpRequest::kInfeedRequest: - handle_status = computation->AddInfeedInstruction(arg->infeed_request()); - break; - case OpRequest::kOutfeedRequest: - handle_status = - computation->AddOutfeedInstruction(arg->outfeed_request()); - break; - case OpRequest::kHostComputeRequest: - handle_status = - computation->AddHostComputeInstruction(arg->host_compute_request()); - break; - case OpRequest::kMapRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->map_request().to_apply())); - handle_status = - computation->AddMapInstruction(arg->map_request(), *to_apply); - break; - } - case OpRequest::kPadRequest: - handle_status = computation->AddPadInstruction(arg->pad_request()); - break; - case OpRequest::kParameterRequest: - handle_status = - computation->AddParameterInstruction(arg->parameter_request()); - break; - case OpRequest::kReduceRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * to_apply, - computation_tracker_.Resolve(arg->reduce_request().to_apply())); - handle_status = - computation->AddReduceInstruction(arg->reduce_request(), *to_apply); - break; - } - case OpRequest::kReducePrecisionRequest: { - handle_status = computation->AddReducePrecisionInstruction( - arg->reduce_precision_request()); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * to_apply, - computation_tracker_.Resolve( - arg->reduce_window_request().to_apply())); - handle_status = computation->AddReduceWindowInstruction( - arg->reduce_window_request(), *to_apply); - break; - } - case OpRequest::kReshapeRequest: - handle_status = - computation->AddReshapeInstruction(arg->reshape_request()); - break; - case OpRequest::kReverseRequest: - handle_status = - computation->AddReverseInstruction(arg->reverse_request()); - break; - case OpRequest::kRngRequest: - handle_status = computation->AddRngInstruction(arg->rng_request()); - break; - case OpRequest::kSelectAndScatterRequest: { - TF_ASSIGN_OR_RETURN(UserComputation * select, - computation_tracker_.Resolve( - arg->select_and_scatter_request().select())); - TF_ASSIGN_OR_RETURN(UserComputation * scatter, - computation_tracker_.Resolve( - arg->select_and_scatter_request().scatter())); - handle_status = computation->AddSelectAndScatterInstruction( - arg->select_and_scatter_request(), *select, *scatter); - break; - } - case OpRequest::kSliceRequest: - handle_status = computation->AddSliceInstruction(arg->slice_request()); - break; - case OpRequest::kTernaryOpRequest: - handle_status = - computation->AddTernaryInstruction(arg->ternary_op_request()); - break; - case OpRequest::kTraceRequest: - return computation->AddTraceInstruction(arg->trace_request()); - case OpRequest::kTransposeRequest: - handle_status = - computation->AddTransposeInstruction(arg->transpose_request()); - break; - case OpRequest::kUnaryOpRequest: - handle_status = computation->AddUnaryInstruction(arg->unary_op_request()); - break; - case OpRequest::kVariadicOpRequest: - handle_status = - computation->AddVariadicInstruction(arg->variadic_op_request()); - break; - case OpRequest::kWhileRequest: { - TF_ASSIGN_OR_RETURN( - UserComputation * condition, - computation_tracker_.Resolve(arg->while_request().condition())); - TF_ASSIGN_OR_RETURN( - UserComputation * body, - computation_tracker_.Resolve(arg->while_request().body())); - handle_status = computation->AddWhileInstruction(arg->while_request(), - *condition, *body); - break; - } - case OpRequest::kSendRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - // Send does not return a value, but we need a handle to be able to - // set OpMetadata and OpSharding (device assignment). - handle_status = computation->AddSendInstruction(arg->send_request()); - break; - } - case OpRequest::kRecvRequest: { - TF_RETURN_IF_ERROR( - channel_tracker_.RegisterRecv(arg->recv_request().channel_handle())); - handle_status = computation->AddRecvInstruction(arg->recv_request()); - break; - } - case OpRequest::OP_NOT_SET: - return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); - default: - return InvalidArgument("Unsupported operation in XLA service"); - } - TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); - - // We set the debug metadata here, because we slice off part of the OpRequest - // proto in the above switch statement. - TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); - TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - if (arg->has_sharding()) { - TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); - } - return Status::OK(); -} - -Status Service::SnapshotComputation(const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr module, - computation_tracker_.SnapshotComputation(arg->computation())); - - result->set_allocated_module(module.release()); - - return Status::OK(); -} - -Status Service::LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) { - TF_ASSIGN_OR_RETURN(*result->mutable_computation(), - computation_tracker_.LoadSessionModule(arg->module())); - return Status::OK(); -} - DeviceHandle Service::SingleComputationDeviceHandle() const { DeviceHandle device_handle; device_handle.set_handle(0); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 81fbd41957887aec763e1cfe165ad0d1d2ac2269..8748a4c1447eca691abc0f7ca48feda48ceb86e1 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -26,17 +26,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/allocation_tracker.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/channel_tracker.h" -#include "tensorflow/compiler/xla/service/compilation_cache.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/executable.h" #include "tensorflow/compiler/xla/service/execution_tracker.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/user_computation.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/service_interface.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -83,11 +78,6 @@ class Service : public ServiceInterface { static StatusOr> NewService( const ServiceOptions& options); - // Creates a new computation with the given name. - // A unique ComputationHandle is returned. - Status Computation(const ComputationRequest* arg, - ComputationResponse* result) override; - // Unregisters a previously-allocated global handle. // // If the handle given is not currently allocated, a NOT_FOUND status is @@ -100,35 +90,15 @@ class Service : public ServiceInterface { Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) override; - // Modifies the provided computation so that subsequent executions - // will compute the provided ComputationDataHandle, rather than the - // last expression enqueued on that Computation. - Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) override; - - // Executes a computation with the provided global data passed as - // immutable arguments. Returns global data output and execution timing. - Status Execute(const ExecuteRequest* arg, ExecuteResponse* result) override; - // Executes a computation with the provided global data passed as // immutable arguments. The request contains the whole computation graph. // Returns global data output and execution timing. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) override; // Executes one or more computations in parallel with the provided global data // passed as immutable arguments. Returns global data output for each // computation. - Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) override; - - // Executes one or more computations in parallel with the provided global data - // passed as immutable arguments. Returns global data output for each - // computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) override; @@ -143,16 +113,6 @@ class Service : public ServiceInterface { Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) override; - // Asynchronously executes a computation with provided arguments. Invokes - // the provided computation with the provided global data passed as - // immutable arguments. Returns a handle to the execution. - // - // (Note: The corresponding function in xla::Client was removed as part of - // b/64116060, in an attempt to simplify our API. We're keeping this around - // for now in case we want to expose this to clients in a different way.) - Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) override; - // Waits until the specified execution is complete and returns the result. // Calling this API multiple times with the same execution handle returns the // method with an error since the execution handle is destroyed after the @@ -190,13 +150,6 @@ class Service : public ServiceInterface { Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) override; - // Tests if an expression is a compile-time constant. - Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) override; - - // Computes the value of a constant expression. - Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) override; Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) override; @@ -205,54 +158,15 @@ class Service : public ServiceInterface { Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) override; - // Returns the program shape of the computation associated with the given - // handle. - Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) override; - - ///// - // Computation-oriented methods. - - // Enqueues an Op on the computation. - Status Op(const OpRequest* arg, OpResponse* result) override; - - // Retrieves the inferred shape for a value within a computation. - Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) override; - // Retrieves the statistics of a computation. - Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) override; - - // Retrieves the statistics of a computation. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. Status GetComputationGraphStats(const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) override; - // Snapshots the current state of a computation handle into a serializable - // protocol buffer form, so it can be loaded via - // LoadComputationSnapshot. - Status SnapshotComputation(const SnapshotComputationRequest* arg, - SnapshotComputationResponse* result) override; - - // Loads a computation from a serialized protocol buffer created via - // SnapshotComputation. - Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* arg, - LoadComputationSnapshotResponse* result) override; - // Creates a unique channel handle that can be used for Send/Recv // instructions. Status CreateChannelHandle(const CreateChannelHandleRequest* arg, CreateChannelHandleResponse* result) override; - // Returns the ComputationTracker of the current service instance. - // Only used in unit tests to access user computations from client. - const ComputationTracker& computation_tracker() { - return computation_tracker_; - } - // Returns the backend used to execute computations. const Backend& backend() const { return *execute_backend_; } Backend* mutable_backend() { return execute_backend_.get(); } @@ -263,8 +177,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions& execution_options); // Picks a parallel response and fills the result. Status PickParallelResponse(const ExecuteParallelResponse& parallel_result, @@ -305,23 +218,13 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - const UserComputation* user_computation = nullptr); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. // // If device_allocator is not null, the compiler may use it to allocate temp // buffers, which the compiler is responsible for freeing. The allocator // given here need not match the allocator used when running the executable. - StatusOr> BuildExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, - DeviceMemoryAllocator* device_allocator = nullptr); - - // Builds an Executable for the given HLO module proto. - // - // TODO(b/74197823): This is a part of a NOT YET ready refactor. StatusOr> BuildExecutable( const HloModuleProto& module_proto, std::unique_ptr module_config, Backend* backend, @@ -330,26 +233,12 @@ class Service : public ServiceInterface { // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. - StatusOr>> BuildExecutables( - std::vector versioned_handles, - std::vector> module_configs, - Backend* backend, std::vector> executors, - DeviceMemoryAllocator* device_allocator); StatusOr>> BuildExecutables( const std::vector& module_protos, std::vector> module_configs, Backend* backend, std::vector> executors, DeviceMemoryAllocator* device_allocator); - // Similar to BuildExecutable, but look in the compilation cache for the - // executable first. If the executable is not in the cache, it is built and - // inserted into the cache. - StatusOr> BuildAndCacheExecutable( - const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, Backend* backend, - se::StreamExecutor* executor, ExecutionProfile* profile, - DeviceMemoryAllocator* device_allocator = nullptr); - // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is // returned. If the parameter "profile" is not null, it points to an @@ -372,17 +261,9 @@ class Service : public ServiceInterface { tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile); - // Convenience function for adding a function to a user computation. - template - Status AddInstruction( - const RequestT* arg, ResponseT* result, - const std::function(UserComputation*)>& - adder); - // Executes a single computation which has more than one target device. // The N devices are expected to all return an empty tuple, but one, which // will be the result of this computation. - Status ExecuteOneToN(const ExecuteRequest* arg, ExecuteResponse* result); Status ExecuteOneToN(const ExecuteGraphRequest* arg, ExecuteResponse* result); // Convenience function which checks whether the given shape_with_layout @@ -405,9 +286,6 @@ class Service : public ServiceInterface { ServiceOptions options_; - // Tracks computations built via the API. - ComputationTracker computation_tracker_; - // Tracks channels created via the API. ChannelTracker channel_tracker_; @@ -417,9 +295,6 @@ class Service : public ServiceInterface { // Tracks asynchronously launched executions via the API. ExecutionTracker execution_tracker_; - // Cache containing previously built Executables. - CompilationCache compilation_cache_; - // Backend to compile and execute computations on. std::unique_ptr execute_backend_; diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto deleted file mode 100644 index bb8d1cd2a106ea3e5bb61eee5052bd60c38cd0e2..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/session.proto +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This proto file defines messages which store the state of XLA -// computations within the XLA service. A computation is stored as a record -// of the operation requests used to build it. -syntax = "proto3"; - -import "tensorflow/compiler/xla/xla_data.proto"; - -package xla; - -// Describes a single operation request. -message OperationRequest { - ComputationDataHandle output_handle = 1; - Shape output_shape = 2; - - // For operations which call embedded computations such as "Map", these are - // the version(s) that the embedded computation should be called at. A version - // value of a computation is the ComputationDataHandle of the root of the - // computation at the point in time. - // - // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single - // embedded computation so this field will have a single value for those - // operations. - // - // "While" operation takes two; index 0 is the "condition" version and index 1 - // is the "body" version. - repeated int64 embedded_computation_versions = 3; - - // The actual request, which in itself is a tagged union of all possible - // operation request types. - OpRequest request = 4; -} - -// Describes a sequence of operation requests which define an XLA -// computation. -message SessionComputation { - string name = 1; - - // The ComputationHandle used to refer to this computation in the XLA - // service. - ComputationHandle computation_handle = 2; - - // Map from ComputationDataHandle value to operation request. The highest - // ComputationDataHandle value corresponds to the root of the computation. - map requests = 3; -} - -// Describes a group of SessionComputations with an "entry point" computation -// that may refer to the other non-entry (AKA embedded) computations. -// -// This message is used to serialize a computation that has been built via the -// XLA service API, along with its dependencies, for purposes such as -// analysis/replay/file-storage. -message SessionModule { - // The entry computation, which was requested for serialization. This may have - // referred to embedded computations, which are reflected below. - SessionComputation entry = 1; - - // Embedded computations that are transitively referred to by the entry - // computation. - repeated SessionComputation embedded_computations = 2; - - // The arguments passed to the computation. - repeated LiteralProto arguments = 3; - - // The result of the computation. - LiteralProto result = 4; - - // The name of the platform used to run the computation. - string execution_platform = 5; -} diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 3500978bdd808f0c7684d14a05636d90105aa594..bd98e86b08b7507b4a7a0d1a7ebac4b654ff2171 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -44,129 +44,6 @@ namespace xla { namespace { -// Return the UnaryOperation proto enum value associated with the given HLO -// opcode. -UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAbs: - return UNOP_ABS; - case HloOpcode::kCeil: - return UNOP_CEIL; - case HloOpcode::kClz: - return UNOP_CLZ; - case HloOpcode::kCos: - return UNOP_COS; - case HloOpcode::kExp: - return UNOP_EXP; - case HloOpcode::kExpm1: - return UNOP_EXPM1; - case HloOpcode::kFloor: - return UNOP_FLOOR; - case HloOpcode::kImag: - return UNOP_IMAG; - case HloOpcode::kIsFinite: - return UNOP_IS_FINITE; - case HloOpcode::kLog: - return UNOP_LOG; - case HloOpcode::kLog1p: - return UNOP_LOG1P; - case HloOpcode::kNot: - return UNOP_NOT; - case HloOpcode::kNegate: - return UNOP_NEGATE; - case HloOpcode::kReal: - return UNOP_REAL; - case HloOpcode::kRoundNearestAfz: - return UNOP_ROUND_NEAREST_AFZ; - case HloOpcode::kSign: - return UNOP_SIGN; - case HloOpcode::kSin: - return UNOP_SIN; - case HloOpcode::kSort: - return UNOP_SORT; - case HloOpcode::kTanh: - return UNOP_TANH; - default: - LOG(FATAL) << "Unhandled opcode for conversion to unary operation: " - << opcode; - } -} - -// Return the BinaryOperation proto enum value associated with the given HLO -// opcode. -BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kAtan2: - return BINOP_ATAN2; - case HloOpcode::kComplex: - return BINOP_COMPLEX; - case HloOpcode::kMultiply: - return BINOP_MUL; - case HloOpcode::kAdd: - return BINOP_ADD; - case HloOpcode::kSubtract: - return BINOP_SUB; - case HloOpcode::kDivide: - return BINOP_DIV; - case HloOpcode::kEq: - return BINOP_EQ; - case HloOpcode::kGe: - return BINOP_GE; - case HloOpcode::kGt: - return BINOP_GT; - case HloOpcode::kLe: - return BINOP_LE; - case HloOpcode::kLt: - return BINOP_LT; - case HloOpcode::kNe: - return BINOP_NE; - case HloOpcode::kMaximum: - return BINOP_MAX; - case HloOpcode::kMinimum: - return BINOP_MIN; - case HloOpcode::kPower: - return BINOP_POW; - case HloOpcode::kRemainder: - return BINOP_REM; - case HloOpcode::kOr: - return BINOP_OR; - case HloOpcode::kAnd: - return BINOP_AND; - case HloOpcode::kShiftLeft: - return BINOP_SHIFT_LEFT; - case HloOpcode::kShiftRightArithmetic: - return BINOP_SHIFT_RIGHT_ARITHMETIC; - case HloOpcode::kShiftRightLogical: - return BINOP_SHIFT_RIGHT_LOGICAL; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the TernaryOperation proto enum value associated with the given HLO -// opcode. -TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kClamp: - return TRIOP_CLAMP; - case HloOpcode::kSelect: - return TRIOP_SELECT; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - -// Return the VariadicOperation proto enum value associated with the given HLO -// opcode. -VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) { - switch (opcode) { - case HloOpcode::kTuple: - return VAROP_TUPLE; - default: - LOG(FATAL) << "unhandled opcode " << opcode; - } -} - // Returns true if no element is present in slice more than once. bool AllUnique(tensorflow::gtl::ArraySlice slice) { return std::set(slice.begin(), slice.end()).size() == slice.size(); @@ -316,88 +193,86 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferUnaryOpShape( HloOpcode opcode, const Shape& shape) { // There is no copy operation at the proto level, so handle copy explicitly. - if (opcode == HloOpcode::kCopy) { + // A domain shape is the same as the input one. + if (opcode == HloOpcode::kCopy || opcode == HloOpcode::kDomain) { return shape; } - return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape); -} + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(shape, "operand of unary operation")); -/* static */ StatusOr ShapeInference::InferUnaryOpShape( - UnaryOperation operation, const Shape& arg) { - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation")); - - TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg)); - switch (operation) { - case UNOP_FLOOR: - case UNOP_CEIL: - if (!ShapeUtil::ElementIsFloating(arg)) { + TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); + switch (opcode) { + case HloOpcode::kFloor: + case HloOpcode::kCeil: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( "Expected element type in shape to be floating for floor/ceil " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_COS: - case UNOP_SIN: - case UNOP_EXP: - case UNOP_EXPM1: - case UNOP_LOG: - case UNOP_LOG1P: - case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg) && - !ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kCos: + case HloOpcode::kSin: + case HloOpcode::kExp: + case HloOpcode::kExpm1: + case HloOpcode::kLog: + case HloOpcode::kLog1p: + case HloOpcode::kTanh: + if (!ShapeUtil::ElementIsFloating(shape) && + !ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be floating or complex for " "sin/cos/exp/log/tanh operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; - case UNOP_REAL: - case UNOP_IMAG: - if (!ShapeUtil::ElementIsComplex(arg)) { + return shape; + case HloOpcode::kReal: + case HloOpcode::kImag: + if (!ShapeUtil::ElementIsComplex(shape)) { return InvalidArgument( "Expected element type in shape to be complex for real/imag " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, F32); - case UNOP_ABS: - if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType(shape, F32); + case HloOpcode::kAbs: + if (ShapeUtil::ElementIsComplex(shape)) { return ShapeUtil::ChangeElementType( - arg, primitive_util::ComplexComponentType(arg.element_type())); + shape, primitive_util::ComplexComponentType(shape.element_type())); } - return arg; - case UNOP_CLZ: - case UNOP_NEGATE: - case UNOP_ROUND_NEAREST_AFZ: - case UNOP_SIGN: - case UNOP_SORT: - return arg; - - case UNOP_NOT: - if (arg.element_type() != PRED && - !primitive_util::IsIntegralType(arg.element_type())) { + return shape; + case HloOpcode::kClz: + case HloOpcode::kNegate: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kSign: + case HloOpcode::kSort: + return shape; + + case HloOpcode::kNot: + if (shape.element_type() != PRED && + !primitive_util::IsIntegralType(shape.element_type())) { return InvalidArgument( "Expected pred or an integral element type in argument to Not " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return arg; + return shape; - case UNOP_IS_FINITE: - if (!ShapeUtil::ElementIsFloating(arg)) { + case HloOpcode::kIsFinite: + if (!ShapeUtil::ElementIsFloating(shape)) { return InvalidArgument( - "Expected element type in shape to be floating point for IsFinite " + "Expected element type in shape to be floating " + "point for IsFinite " "operation; got %s.", - PrimitiveType_Name(arg.element_type()).c_str()); + PrimitiveType_Name(shape.element_type()).c_str()); } - return ShapeUtil::ChangeElementType(arg, PRED); + return ShapeUtil::ChangeElementType(shape, PRED); default: return InvalidArgument( "Unknown operation for unary shape inference: \"%s\".", - UnaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -462,6 +337,17 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr ShapeInference::InferTokenShape( + tensorflow::gtl::ArraySlice arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); @@ -767,8 +653,9 @@ Status ValidateDotDimensionNumbers( } /* static */ StatusOr -ShapeInference::InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs) { +ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, + const Shape& lhs, + const Shape& rhs) { TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)); // The shapes have to be compatible. That is, if some dimension d has a @@ -786,7 +673,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } else { return InvalidArgument( "Binary op %s with incompatible shapes: %s and %s.", - BinaryOperation_Name(operation).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -796,8 +683,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions) { if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) { // Reject "magic" inference for binops on different shapes, requiring @@ -898,7 +784,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } /* static */ StatusOr ShapeInference::InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation")); @@ -908,8 +794,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Binary op %s with different element types: %s and %s.", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), + HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } @@ -942,10 +827,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs; // After InDim broadcasting, perform degenerate dimensions broadcasting. - TF_ASSIGN_OR_RETURN( - Shape indim_broadcast_shape, - InferInDimBroadcastShape(operation, smaller_shape, larger_shape, - broadcast_dimensions)); + TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape, + InferInDimBroadcastShape(smaller_shape, larger_shape, + broadcast_dimensions)); return InferDegenerateDimensionBroadcastShape( operation, indim_broadcast_shape, larger_shape); @@ -954,51 +838,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(), - rhs->shape(), /*broadcast_dimensions=*/{}); + return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(), + /*broadcast_dimensions=*/{}); } /* static */ StatusOr ShapeInference::InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs, - broadcast_dimensions); -} - -/* static */ StatusOr ShapeInference::InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions) { VLOG(2) << tensorflow::strings::Printf( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), + HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(), + ShapeUtil::HumanString(rhs).c_str(), Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( lhs, tensorflow::strings::StrCat("lhs of binary operation ", - BinaryOperation_Name(operation)))); + HloOpcodeString(opcode)))); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( rhs, tensorflow::strings::StrCat("rhs of binary operation ", - BinaryOperation_Name(operation)))); - switch (operation) { - case BINOP_MAX: - case BINOP_MIN: - case BINOP_SUB: - case BINOP_ADD: - case BINOP_ATAN2: - case BINOP_POW: - case BINOP_DIV: - case BINOP_REM: - case BINOP_MUL: - case BINOP_SHIFT_LEFT: - case BINOP_SHIFT_RIGHT_ARITHMETIC: - case BINOP_SHIFT_RIGHT_LOGICAL: - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + HloOpcodeString(opcode)))); + switch (opcode) { + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kSubtract: + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kPower: + case HloOpcode::kDivide: + case HloOpcode::kRemainder: + case HloOpcode::kMultiply: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_COMPLEX: { + case HloOpcode::kComplex: { if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( "Expected element type in shape to be floating for complex compose " @@ -1006,7 +883,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(lhs.element_type()).c_str()); } TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); @@ -1014,8 +891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return Unimplemented("Complex component type is not implemented."); } } - case BINOP_AND: - case BINOP_OR: + case HloOpcode::kAnd: + case HloOpcode::kOr: if (lhs.element_type() != PRED && !primitive_util::IsIntegralType(lhs.element_type())) { return InvalidArgument( @@ -1023,24 +900,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "got %s.", PrimitiveType_Name(lhs.element_type()).c_str()); } - return InferElementwiseBinaryOpShape(operation, lhs, rhs, + return InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: - case BINOP_GE: - case BINOP_GT: - case BINOP_LE: - case BINOP_LT: - case BINOP_NE: { + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: { TF_ASSIGN_OR_RETURN(const Shape& shape, - InferElementwiseBinaryOpShape(operation, lhs, rhs, + InferElementwiseBinaryOpShape(opcode, lhs, rhs, broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } default: return Unimplemented( "Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.", - BinaryOperation_Name(operation).c_str(), - lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str()); + HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(), + rhs.ShortDebugString().c_str()); } } @@ -1052,23 +929,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferTernaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) { - return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs); -} - -/* static */ StatusOr ShapeInference::InferTernaryOpShape( - TernaryOperation operation, const Shape& lhs, const Shape& rhs, - const Shape& ehs) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs)); - switch (operation) { - case TRIOP_CLAMP: + switch (opcode) { + case HloOpcode::kClamp: return InferClampShape(lhs, rhs, ehs); - case TRIOP_SELECT: + case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", - TernaryOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } @@ -1085,18 +956,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes) { - return InferVariadicOpShape(OpcodeToVariadicOperation(opcode), - operand_shapes); -} - -/* static */ StatusOr ShapeInference::InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes) { for (const Shape* shape : operand_shapes) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape)); } - switch (operation) { - case VAROP_TUPLE: { + switch (opcode) { + case HloOpcode::kTuple: { Shape result = ShapeUtil::MakeTupleShape({}); for (const Shape* shape : operand_shapes) { ShapeUtil::AppendShapeToTuple(*shape, &result); @@ -1105,7 +969,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } default: return InvalidArgument("Unknown operation %s.", - VariadicOperation_Name(operation).c_str()); + HloOpcodeString(opcode).c_str()); } } diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 9da2c99b4177f08ece8daabaf2922ddd7e947a1b..f1f7b50902d899c0c629c3098d80fc400fb1388d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -46,8 +46,6 @@ class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the // given input shape. - static StatusOr InferUnaryOpShape(UnaryOperation operation, - const Shape& arg); static StatusOr InferUnaryOpShape(HloOpcode opcode, const Shape& shape); static StatusOr InferUnaryOpShape(HloOpcode opcode, @@ -55,9 +53,6 @@ class ShapeInference { // Infers the shape produced by applying the given binary operation to the // given input shapes. - static StatusOr InferBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, - tensorflow::gtl::ArraySlice broadcast_dimensions); static StatusOr InferBinaryOpShape( HloOpcode opcode, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); @@ -67,9 +62,6 @@ class ShapeInference { // Infers the shape produced by applying the given ternary operation to the // given input shapes. - static StatusOr InferTernaryOpShape(TernaryOperation operation, - const Shape& lhs, const Shape& rhs, - const Shape& ehs); static StatusOr InferTernaryOpShape(HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs); @@ -80,9 +72,6 @@ class ShapeInference { // Infers the shape produced by applying the given variadic operation to the // given input operand shapes. - static StatusOr InferVariadicOpShape( - VariadicOperation operation, - tensorflow::gtl::ArraySlice operand_shapes); static StatusOr InferVariadicOpShape( HloOpcode opcode, tensorflow::gtl::ArraySlice operand_shapes); @@ -227,6 +216,13 @@ class ShapeInference { static StatusOr InferConcatOpShape( tensorflow::gtl::ArraySlice arg_shapes, int64 dimension); + // Infers the shape produced by a kGenerateToken operation. Trivially this + // shape is always a TOKEN shape. However, ShapeInference serves two purposes: + // inferring shapes and checking operand shapes. This method verifies that the + // operand shapes are all TOKENs. + static StatusOr InferTokenShape( + tensorflow::gtl::ArraySlice arg_shapes); + // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. @@ -279,7 +275,7 @@ class ShapeInference { // the LHS and a single element in the RHS to produce a single output element, // even in the presence of broadcasting of one of the operands over the other. static StatusOr InferElementwiseBinaryOpShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs, + HloOpcode operation, const Shape& lhs, const Shape& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions); // Helper for inferring the shape of Clamp ops. @@ -295,7 +291,7 @@ class ShapeInference { // dimension broadcasting (a dimension of size 1 in one operand is broadcast // up to match the size of the dimension in the other operand). static StatusOr InferDegenerateDimensionBroadcastShape( - BinaryOperation operation, const Shape& lhs, const Shape& rhs); + HloOpcode operation, const Shape& lhs, const Shape& rhs); // Helper for inferring shapes of binary operations using "InDim" // broadcasting. This is the broadcasting used in the *InDim binary operations @@ -303,8 +299,7 @@ class ShapeInference { // lower-rank shape than larger_shape. Returns the shape that the // smaller_shape is broadcast to. static StatusOr InferInDimBroadcastShape( - BinaryOperation operation, const Shape& smaller_shape, - const Shape& larger_shape, + const Shape& smaller_shape, const Shape& larger_shape, tensorflow::gtl::ArraySlice broadcast_dimensions); TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference); diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 0e61994a786b53a295ef9c9c2287b28fbf754d9b..6d017dffe2d8f927abad4a62bff7fe41bc871975 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -101,8 +101,8 @@ class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest { TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}); - auto inferred_status = ShapeInference::InferUnaryOpShape( - UnaryOperation::UNOP_NEGATE, matrix_shape); + auto inferred_status = + ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie())); } @@ -110,14 +110,14 @@ TEST_F(ShapeInferenceTest, UnaryNegateMatrix) { TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) { Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple); + HloOpcode::kSelect, pred_, tuple, tuple); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } @@ -125,34 +125,34 @@ TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) { TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) { auto predarray = ShapeUtil::MakeShape(PRED, {64, 48}); auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, SelectBadShapes) { auto inferred_status_error1 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_); + HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Operands to select must be the same shape")); auto inferred_status_error2 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("pred operand must have PRED")); auto inferred_status_error3 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}), - matrix_64_48_, matrix_64_48_); + HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_, + matrix_64_48_); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("with non-scalar predicate with dimensionality")); // Tuples have a TUPLE element type and cannot be the pred of a select. auto inferred_status_error4 = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}), + HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}), ShapeUtil::MakeTupleShape({f32_, f32_}), ShapeUtil::MakeTupleShape({f32_, f32_})); ASSERT_FALSE(inferred_status_error4.ok()); @@ -162,102 +162,98 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { TEST_F(ShapeInferenceTest, ClampAllMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, - matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampAllScalar) { - auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, f32_); + auto inferred_status = + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, matrix_64_48_); + HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, matrix_64_48_, f32_); + HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandScalar) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, matrix_64_48_); + HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMinMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, matrix_64_48_, f32_, f32_); + HloOpcode::kClamp, matrix_64_48_, f32_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampMaxMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, matrix_64_48_); + HloOpcode::kClamp, f32_, f32_, matrix_64_48_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampOperandMatrix) { auto inferred_status = ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, matrix_64_48_, f32_); + HloOpcode::kClamp, f32_, matrix_64_48_, f32_); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie())); } TEST_F(ShapeInferenceTest, ClampBadShapes) { // Type mismatch - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, s32_, f32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, s32_, f32_) - .ok()); - ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, f32_, s32_) - .ok()); - // Dimension mismatch ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_64_, vector_32_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_64_, vector_32_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_) .ok()); ASSERT_FALSE( - ShapeInference::InferTernaryOpShape(TernaryOperation::TRIOP_CLAMP, - vector_32_, vector_32_, vector_64_) + ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_) .ok()); - // Dimension mismatch, where one operand is a scalar + // Dimension mismatch ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, vector_32_, f32_) + HloOpcode::kClamp, vector_64_, vector_32_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, vector_64_, f32_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_64_, vector_32_) .ok()); ASSERT_FALSE(ShapeInference::InferTernaryOpShape( - TernaryOperation::TRIOP_CLAMP, f32_, vector_64_, vector_32_) + HloOpcode::kClamp, vector_32_, vector_32_, vector_64_) + .ok()); + // Dimension mismatch, where one operand is a scalar + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, vector_32_, f32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, + vector_64_, f32_, vector_32_) + .ok()); + ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, + vector_64_, vector_32_) .ok()); } TEST_F(ShapeInferenceTest, Complex) { auto complex_shape = [&](const Shape& lhs, const Shape& rhs, const tensorflow::gtl::ArraySlice& bcast) { - return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, - lhs, rhs, bcast); + return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs, + bcast); }; // Inputs must be FP. ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); @@ -292,8 +288,8 @@ TEST_F(ShapeInferenceTest, Complex) { } TEST_F(ShapeInferenceTest, VariadicOpTuplify) { - StatusOr result = ShapeInference::InferVariadicOpShape( - VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); + StatusOr result = + ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); ASSERT_IS_OK(result.status()); ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), ShapeUtil::MakeTupleShape({s32_, f32_}))); @@ -804,8 +800,8 @@ TEST_F(ShapeInferenceTest, InferConstIndexShape) { TEST_F(ShapeInferenceTest, InferPowShape) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); - auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {}); + auto inferred_status = ShapeInference::InferBinaryOpShape( + HloOpcode::kPower, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie())); } @@ -813,7 +809,7 @@ TEST_F(ShapeInferenceTest, InferPowShape) { TEST_F(ShapeInferenceTest, InferCompareShapeEq) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kEq, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -822,7 +818,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeEq) { TEST_F(ShapeInferenceTest, InferCompareShapeGe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -831,7 +827,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGe) { TEST_F(ShapeInferenceTest, InferCompareShapeGt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kGt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -840,7 +836,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeGt) { TEST_F(ShapeInferenceTest, InferCompareShapeLe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -849,7 +845,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLe) { TEST_F(ShapeInferenceTest, InferCompareShapeLt) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kLt, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -858,7 +854,7 @@ TEST_F(ShapeInferenceTest, InferCompareShapeLt) { TEST_F(ShapeInferenceTest, InferCompareShapeNe) { auto ten_floats = ShapeUtil::MakeShape(F32, {10}); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {}); + ShapeInference::InferBinaryOpShape(HloOpcode::kNe, ten_floats, f32_, {}); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}), inferred_status.ValueOrDie())); @@ -1111,22 +1107,22 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { const Shape vec8 = ShapeUtil::MakeShape(F32, {8}); const Shape vec16 = ShapeUtil::MakeShape(F32, {16}); - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {1}); + auto inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec8, {0}); + auto inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0}); ASSERT_FALSE(inferred_status_mismatch.ok()); - inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {0}); + inferred_status_match = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat)); - inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, mat, vec16, {1}); + inferred_status_mismatch = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1}); ASSERT_FALSE(inferred_status_mismatch.ok()); } @@ -1138,17 +1134,17 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) { const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8}); auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2}); + HloOpcode::kAdd, cube, matrix8_4, {1, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2}); + HloOpcode::kAdd, cube, matrix16_4, {0, 2}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1}); + HloOpcode::kAdd, cube, matrix16_8, {0, 1}); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube)); } @@ -1162,43 +1158,43 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8}); // "magical" broadcast rejected - auto inferred_status_error1 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {}); + auto inferred_status_error1 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("Automatic")); // broadcast_dimension out of bounds for tensor's rank - auto inferred_status_error2 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {3}); + auto inferred_status_error2 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), ContainsRegex("Broadcast dimension number .* too large")); // broadcast_dimension doesn't match corresponding dimension - auto inferred_status_error3 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, vec8, {0}); + auto inferred_status_error3 = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("Broadcast dimension 0 mismatch")); // broadcast_dimensions list too long auto inferred_status_error4 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2}); + HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2}); ASSERT_FALSE(inferred_status_error4.ok()); ASSERT_THAT(inferred_status_error4.status().error_message(), HasSubstr("broadcast_dimensions has to match")); // there's a dimension above the rank of the tensor auto inferred_status_error5 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0}); + HloOpcode::kAdd, tensor, matrix8_4, {3, 0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), ContainsRegex("dimension number .* too large")); // broadcasting dimensions don't match in this order auto inferred_status_error6 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1}); + HloOpcode::kAdd, tensor, matrix8_4, {2, 1}); ASSERT_FALSE(inferred_status_error6.ok()); ASSERT_THAT(inferred_status_error6.status().error_message(), HasSubstr("dimension 0 mismatch")); @@ -1207,13 +1203,13 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { // in a proper (strictly increasing) order, even if the lower-rank array // matches the higher-rank array in many different ways. auto inferred_status_error7 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0}); ASSERT_FALSE(inferred_status_error7.ok()); ASSERT_THAT(inferred_status_error7.status().error_message(), HasSubstr("dimensions order is wrong")); auto inferred_status_error8 = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0}); + HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0}); ASSERT_FALSE(inferred_status_error8.ok()); ASSERT_THAT(inferred_status_error8.status().error_message(), HasSubstr("dimensions order is wrong")); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index ba16dc640e2d2974eab4fc8b134a6e33c03e3b85..49e1f873192f800056a2272f7d4f698898b0f8a1 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -178,7 +178,6 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto new_conv = HloInstruction::CreateConvolve( convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); - convolution.SetupDerivedInstruction(new_conv.get()); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index f73f1227aaf1630a9e7c43bb508732c5518ef929..cccb8f2fbb0266bbf1f40b09170938a1e5d3e78d 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -27,12 +27,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" @@ -69,7 +69,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -91,7 +91,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -119,7 +119,7 @@ ENTRY entry_computation { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TransposeFolding transpose_folding( [](const HloInstruction& dot, @@ -147,7 +147,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); @@ -176,7 +176,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { HloComputation* entry_computation = module->AddEntryComputation(builder.Build(mul)); HloInstruction* call = module->OutlineExpressionFromComputation( - {add, sub, mul}, "", entry_computation); + {add, sub, mul}, "entry", entry_computation); EXPECT_EQ(call, entry_computation->root_instruction()); HloComputation* callee_computation = call->to_apply(); // The arguments to the call should be const1, const2, and const3. @@ -205,7 +205,7 @@ ENTRY entry_computation { } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); FoldTranspose(module.get()); const HloComputation* callee = module->GetComputationWithName("callee"); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 8cb654493ca82dc702b2c1e7a4284f4f31d1e5f9..eb6d1ada6b553f998fe06917dfdf0b5092cd79cd 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,14 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleDomain(HloInstruction* domain) { + // A kDomain instruction aliases its operand. That is, the buffer of its + // result *is* the buffer of its operand, so just copy the operands points-to + // set. + CreateCopiedPointsToSet(domain, domain->operand(0)); + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { // A kSlice instruction aliases its operand if the backend lowers it to an // in-place implementation. @@ -715,15 +723,16 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return false; } if (user->opcode() == HloOpcode::kFusion) { - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && - user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - // Loop fusion with kDynamicUpdateSlice fused root. - // - // Returns true iff there is exactly one use of 'operand' at shape index - // 'operand_index', and this singleton use is the fused root at operand - // index 0. - return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. + return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0); + } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -781,8 +790,12 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser( return param_uses.size() == 1 && param_uses[0].first == callee_root && callee_root->IsElementwiseOnOperand(param_uses[0].second); } - // Check if 'user' is element-wise. - return user->IsElementwise(); + // Loop fusions that contain transposing copies won't reach here as they have + // different layouts, which fails the check in the beginning of this function. + // + // Multi-output fusion will fail the check here as tuples are not considered + // an elementwise operation. + return user->IsElementwiseOnOperand(user->operand_index(operand)); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 1ac713013650d807b15e33565e6d2dec406a5d13..c0d82414806d9a6ff57aec59d077f444137fec9a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -248,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleDomain(HloInstruction* domain) override; Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index f558316b05b168a6f100e8ef69adfd9dbc023102..5734f284071944bc22011405898cf86f33dc48d7 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -1148,5 +1148,30 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { call, {})); } +TEST_F(CanShareOperandBufferWithUserTest, LoopFusionWithElementwiseOperand) { + Shape full_shape = ShapeUtil::MakeShape(F32, {16, 32}); + Shape broadcast_shape = ShapeUtil::MakeShape(F32, {16}); + + auto builder = HloComputation::Builder(TestName() + "_fusion"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "full")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, broadcast_shape, "small")); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(full_shape, param1, {0})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + full_shape, HloOpcode::kAdd, param0, broadcast)); + + BuildModule(builder.Build()); + auto fusion = computation_->CreateFusionInstruction( + {add, broadcast}, HloInstruction::FusionKind::kLoop); + RunAnalysis(); + + EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(param0, {}, + fusion, {})); + EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(param1, {}, + fusion, {})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc index d668855084a884518b338cdf396a9330b9f43a2b..77bdcc9de0d830991208a1db271d009bccaf550e 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -30,10 +30,17 @@ limitations under the License. namespace xla { +TupleSimplifier::TupleSimplifier(bool exclude_entry_computation) : + exclude_entry_computation_(exclude_entry_computation) {} + StatusOr TupleSimplifier::Run(HloModule* module) { // Initially add all GTE and Tuple instructions to the worklist. std::queue worklist; for (auto* computation : module->computations()) { + if (exclude_entry_computation_ && + computation == module->entry_computation()) { + continue; + } for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kTuple || instruction->opcode() == HloOpcode::kGetTupleElement) { @@ -69,7 +76,6 @@ StatusOr TupleSimplifier::Run(HloModule* module) { // Tuple // HloInstruction* top_tuple = nullptr; - HloInstruction* first_gte = nullptr; bool can_simplify = true; for (int64 operand_number = 0; operand_number < instruction->operand_count(); ++operand_number) { @@ -79,17 +85,10 @@ StatusOr TupleSimplifier::Run(HloModule* module) { can_simplify = false; break; } - if (first_gte == nullptr) { - first_gte = operand; - } else if (!first_gte->has_compatible_sharding(operand)) { - can_simplify = false; - break; - } if (top_tuple == nullptr) { top_tuple = operand->mutable_operand(0); if (!ShapeUtil::Compatible(top_tuple->shape(), - instruction->shape()) || - !instruction->has_compatible_sharding(top_tuple)) { + instruction->shape())) { can_simplify = false; break; } @@ -118,14 +117,12 @@ StatusOr TupleSimplifier::Run(HloModule* module) { HloInstruction* element_source = instruction->mutable_operand(0)->mutable_operand( instruction->tuple_index()); - if (instruction->has_compatible_sharding(element_source)) { - changed = true; - TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); - for (HloInstruction* user : element_source->users()) { - if (user->opcode() == HloOpcode::kTuple || - user->opcode() == HloOpcode::kGetTupleElement) { - worklist.push(user); - } + changed = true; + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); } } } diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h index e5e9b10b5bf3f452d1bfec476b8d5c7d74c4f4e8..750950188312c5077d487f2feef0606f07839432 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier.h +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -27,13 +27,20 @@ namespace xla { // the module. class TupleSimplifier : public HloPassInterface { public: - TupleSimplifier() {} + TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} + explicit TupleSimplifier(bool exclude_entry_computation); ~TupleSimplifier() override {} tensorflow::StringPiece name() const override { return "tuple-simplifier"; } // Run tuple simplification on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // When set, this pipeline stage will perform optimization of all computations + // apart from the module's entry computation. This is used by Graphcore's + // backend. + bool exclude_entry_computation_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc index ca9ae91281fce5ee061d066fc3e538dbbc09f6b3..d3635eae81ec7017f9bf6a69250d10716309c9ec 100644 --- a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -42,6 +42,12 @@ class TupleSimplifierTest : public HloTestBase { TF_ASSERT_OK(changed_status.status()); EXPECT_EQ(change_expected, changed_status.ValueOrDie()); } + void Run(HloModule* module, bool change_expected, bool exclude_entry) { + TupleSimplifier simplifier(exclude_entry); + auto changed_status = simplifier.Run(module); + TF_ASSERT_OK(changed_status.status()); + EXPECT_EQ(change_expected, changed_status.ValueOrDie()); + } const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( @@ -211,5 +217,76 @@ TEST_F(TupleSimplifierTest, IncompatibleTuples) { EXPECT_THAT(computation->root_instruction(), tuple); } +TEST_F(TupleSimplifierTest, CanExcludeEntryComputation) { + // Verify that the root computation can be excluded + auto module = CreateNewModule(); + + HloInstruction* p0; + HloInstruction* p1; + HloComputation* c0; + HloComputation* c1; + HloComputation* entry; + + { + HloComputation::Builder builder(TestName() + "_1"); + p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p0, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c0 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_2"); + p1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, p1, 2)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + c1 = module->AddEmbeddedComputation(builder.Build()); + } + { + HloComputation::Builder builder(TestName() + "_Entry"); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* call0 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c0)); + HloInstruction* call1 = builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape_, {tuple_param}, c1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call0, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, call1, 1)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 0)); + HloInstruction* gte3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple0, 1)); + + builder.AddInstruction(HloInstruction::CreateTuple({gte2, gte3})); + + entry = module->AddEntryComputation(builder.Build()); + } + + Run(module.get(), /*change_expected=*/true, /*exclude_entry=*/ true); + + EXPECT_THAT(c0->root_instruction(), p0); + EXPECT_THAT(c1->root_instruction(), p1); + EXPECT_THAT(entry->instruction_count(), 9); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc index 754fd8ef169231827eeb5bfd72aeb596644ca767..d33d5bb8f30c8504aa323d461e5f59709b48e1fc 100644 --- a/tensorflow/compiler/xla/service/tuple_util_test.cc +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/tuple_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -37,7 +37,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc deleted file mode 100644 index 9e62d0acfb98946f1e693fc0310098b4ec99750b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ /dev/null @@ -1,3557 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace xla { -namespace { - -HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { - switch (unop) { - case UNOP_ABS: - return HloOpcode::kAbs; - case UNOP_CEIL: - return HloOpcode::kCeil; - case UNOP_CLZ: - return HloOpcode::kClz; - case UNOP_COS: - return HloOpcode::kCos; - case UNOP_EXP: - return HloOpcode::kExp; - case UNOP_EXPM1: - return HloOpcode::kExpm1; - case UNOP_FLOOR: - return HloOpcode::kFloor; - case UNOP_IMAG: - return HloOpcode::kImag; - case UNOP_IS_FINITE: - return HloOpcode::kIsFinite; - case UNOP_LOG: - return HloOpcode::kLog; - case UNOP_LOG1P: - return HloOpcode::kLog1p; - case UNOP_NOT: - return HloOpcode::kNot; - case UNOP_NEGATE: - return HloOpcode::kNegate; - case UNOP_REAL: - return HloOpcode::kReal; - case UNOP_ROUND_NEAREST_AFZ: - return HloOpcode::kRoundNearestAfz; - case UNOP_SIGN: - return HloOpcode::kSign; - case UNOP_SIN: - return HloOpcode::kSin; - case UNOP_SORT: - return HloOpcode::kSort; - case UNOP_TANH: - return HloOpcode::kTanh; - default: - LOG(FATAL) << "unhandled operation " << unop; - } -} - -HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { - switch (binop) { - case BINOP_ATAN2: - return HloOpcode::kAtan2; - case BINOP_COMPLEX: - return HloOpcode::kComplex; - case BINOP_MUL: - return HloOpcode::kMultiply; - case BINOP_ADD: - return HloOpcode::kAdd; - case BINOP_SUB: - return HloOpcode::kSubtract; - case BINOP_DIV: - return HloOpcode::kDivide; - case BINOP_EQ: - return HloOpcode::kEq; - case BINOP_GE: - return HloOpcode::kGe; - case BINOP_GT: - return HloOpcode::kGt; - case BINOP_LE: - return HloOpcode::kLe; - case BINOP_LT: - return HloOpcode::kLt; - case BINOP_NE: - return HloOpcode::kNe; - case BINOP_MAX: - return HloOpcode::kMaximum; - case BINOP_MIN: - return HloOpcode::kMinimum; - case BINOP_POW: - return HloOpcode::kPower; - case BINOP_REM: - return HloOpcode::kRemainder; - case BINOP_OR: - return HloOpcode::kOr; - case BINOP_AND: - return HloOpcode::kAnd; - case BINOP_SHIFT_LEFT: - return HloOpcode::kShiftLeft; - case BINOP_SHIFT_RIGHT_ARITHMETIC: - return HloOpcode::kShiftRightArithmetic; - case BINOP_SHIFT_RIGHT_LOGICAL: - return HloOpcode::kShiftRightLogical; - default: - LOG(FATAL) << "unhandled operation " << binop; - } -} - -HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { - switch (triop) { - case TRIOP_CLAMP: - return HloOpcode::kClamp; - case TRIOP_SELECT: - return HloOpcode::kSelect; - default: - LOG(FATAL) << "unhandled operation " << triop; - } -} - -HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) { - switch (varop) { - case VAROP_TUPLE: - return HloOpcode::kTuple; - default: - LOG(FATAL) << "unhandled operation " << varop; - } -} - -} // namespace - -/* static */ StatusOr> -UserComputation::MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new) { - auto user_computation = - MakeUnique(session_computation.name(), handle); - { - tensorflow::mutex_lock lock(user_computation->mutex_); - user_computation->session_computation_ = session_computation; - user_computation->next_handle_value_ = - std::max_element(session_computation.requests().begin(), - session_computation.requests().end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return lhs.first < rhs.first; - }) - ->first + - 1; - TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new)); - } - - return std::move(user_computation); -} - -UserComputation::UserComputation(const string& name, - const ComputationHandle& handle) - : name_(name), next_handle_value_(1) { - *session_computation_.mutable_computation_handle() = handle; - session_computation_.set_name(name); - - VLOG(1) << "New UserComputation \"" << name - << "\", handle: " << handle.handle(); -} - -ComputationDataHandle UserComputation::CreateComputationDataHandle() { - ComputationDataHandle handle; - handle.set_handle(next_handle_value_); - // Handles are used as Version values and *must* be assigned consecutively for - // computation versioning to work. - next_handle_value_++; - return handle; -} - -StatusOr UserComputation::AddParameterInstruction( - const ParameterRequest& parameter_request) { - tensorflow::mutex_lock lock(mutex_); - - int64 parameter_number = parameter_request.parameter(); - if (parameters_.count(parameter_number) != 0) { - return InvalidArgument("parameter %lld already registered", - parameter_number); - } - ComputationDataHandle handle = CreateComputationDataHandle(); - - const Shape& validated_shape = parameter_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_parameter_request() = parameter_request; - - parameters_[parameter_number] = &request; - - VLOG(1) << "AddParameterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << parameter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSendInstruction( - const SendRequest& send_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check if the operand of the instruction is valid. - TF_RETURN_IF_ERROR(LookUpRequest(send_request.operand()).status()); - - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_send_request() = send_request; - - VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << send_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddRecvInstruction( - const RecvRequest& recv_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = recv_request.shape(); - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_recv_request() = recv_request; - - VLOG(1) << "AddRecvInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << recv_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddPadInstruction( - const PadRequest& pad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(pad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value, - LookUpRequest(pad_request.padding_value())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape( - operand->output_shape(), - padding_value->output_shape(), - pad_request.padding_config())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_pad_request() = pad_request; - - VLOG(1) << "AddPadInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << pad_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConstantInstruction( - const ConstantRequest& constant_request) { - const Shape& validated_shape = constant_request.literal().shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - tensorflow::mutex_lock lock(mutex_); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_constant_request() = constant_request; - - VLOG(1) << "AddConstantInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle(); - return handle; -} - -StatusOr UserComputation::AddGatherInstruction( - const GatherRequest& gather_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, - LookUpRequest(gather_request.input())); - TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, - LookUpRequest(gather_request.gather_indices())); - - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferGatherShape( - input_request->output_shape(), gather_indices_request->output_shape(), - gather_request.dimension_numbers(), - AsInt64Slice(gather_request.window_bounds()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_gather_request() = gather_request; - - VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << gather_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(get_tuple_element_request.operand())); - if (!ShapeUtil::IsTuple(operand->output_shape())) { - return InvalidArgument( - "Operand to GetTupleElement() is not a tuple; got %s", - ShapeUtil::HumanString(operand->output_shape()).c_str()); - } - Shape element_shape = ShapeUtil::GetTupleElementShape( - operand->output_shape(), get_tuple_element_request.index()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = element_shape; - *request.mutable_request()->mutable_get_tuple_element_request() = - get_tuple_element_request; - - VLOG(1) << "AddGetTupleElementInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << get_tuple_element_request.ShortDebugString(); - return handle; -} - -Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the operand index is valid. - TF_RETURN_IF_ERROR(LookUpRequest(trace_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = ShapeUtil::MakeNil(); - *request.mutable_request()->mutable_trace_request() = trace_request; - - VLOG(1) << "AddTraceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << trace_request.ShortDebugString(); - return Status::OK(); -} - -StatusOr UserComputation::AddRngInstruction( - const RngRequest& rng_request) { - tensorflow::mutex_lock lock(mutex_); - - // Check the number of parameters per RNG distribution. - switch (rng_request.distribution()) { - case RandomDistribution::RNG_NORMAL: - case RandomDistribution::RNG_UNIFORM: - if (rng_request.parameter_size() != 2) { - return InvalidArgument( - "RNG distribution (%s) expects 2 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; - default: - LOG(FATAL) << "unhandled distribution " << rng_request.distribution(); - } - - // Verify that the parameter indices are valid; - for (const ComputationDataHandle& param : rng_request.parameter()) { - TF_RETURN_IF_ERROR(LookUpRequest(param).status()); - } - const Shape& validated_shape = rng_request.shape(); - TF_RETURN_IF_ERROR( - ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = validated_shape; - *request.mutable_request()->mutable_rng_request() = rng_request; - - VLOG(1) << "AddRngInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << rng_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : map_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, - AsInt64Slice(map_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_map_request() = map_request; - - VLOG(1) << "AddMapInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << map_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceShape( - operand->output_shape(), init_value->output_shape(), - AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_request() = reduce_request; - - VLOG(1) << "AddReduceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_training_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_training_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_training_request.offset())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormTrainingShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), batch_norm_training_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_training_request() = - batch_norm_training_request; - - VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_training_request.ShortDebugString(); - - return handle; -} - -StatusOr -UserComputation::AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_inference_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_inference_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* offset, - LookUpRequest(batch_norm_inference_request.offset())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_inference_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_inference_request.variance())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBatchNormInferenceShape( - operand->output_shape(), scale->output_shape(), - offset->output_shape(), mean->output_shape(), - variance->output_shape(), - batch_norm_inference_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_inference_request() = - batch_norm_inference_request; - - VLOG(1) << "AddBatchNormInferenceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << batch_norm_inference_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(batch_norm_grad_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* scale, - LookUpRequest(batch_norm_grad_request.scale())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* mean, - LookUpRequest(batch_norm_grad_request.mean())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* variance, - LookUpRequest(batch_norm_grad_request.variance())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* grad_output, - LookUpRequest(batch_norm_grad_request.grad_output())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferBatchNormGradShape( - operand->output_shape(), scale->output_shape(), mean->output_shape(), - variance->output_shape(), grad_output->output_shape(), - batch_norm_grad_request.feature_index())); - - *request.mutable_output_shape() = inferred_shape; - - *request.mutable_output_handle() = handle; - - *request.mutable_request()->mutable_batch_norm_grad_request() = - batch_norm_grad_request; - - VLOG(1) << "AddBatchNormGradInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << batch_norm_grad_request.ShortDebugString(); - - return handle; -} - -StatusOr UserComputation::AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_window_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(reduce_window_request.init_value())); - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReduceWindowShape( - operand->output_shape(), init_value->output_shape(), - reduce_window_request.window(), *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_reduce_window_request() = - reduce_window_request; - - VLOG(1) << "AddReduceWindowInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_window_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(select_and_scatter_request.operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* source, - LookUpRequest(select_and_scatter_request.source())); - TF_ASSIGN_OR_RETURN(const OperationRequest* init_value, - LookUpRequest(select_and_scatter_request.init_value())); - - VersionedComputationHandle::Version select_version = - select_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr select_program_shape, - select_computation.ComputeProgramShape(select_version)); - VersionedComputationHandle::Version scatter_version = - scatter_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr scatter_program_shape, - scatter_computation.ComputeProgramShape(scatter_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferSelectAndScatterShape( - operand->output_shape(), *select_program_shape, - select_and_scatter_request.window(), source->output_shape(), - init_value->output_shape(), *scatter_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(select_version); - request.add_embedded_computation_versions(scatter_version); - *request.mutable_request()->mutable_select_and_scatter_request() = - select_and_scatter_request; - - VLOG(1) << "AddSelectAndScatterInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << select_and_scatter_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReverseInstruction( - const ReverseRequest& reverse_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reverse_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReverseShape( - operand->output_shape(), AsInt64Slice(reverse_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reverse_request() = reverse_request; - VLOG(1) << "AddReverseInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reverse_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* init, - LookUpRequest(while_request.init())); - - VersionedComputationHandle::Version condition_version = - condition_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr condition_program_shape, - condition_computation.ComputeProgramShape(condition_version)); - - VersionedComputationHandle::Version body_version = body_computation.version(); - TF_ASSIGN_OR_RETURN(std::shared_ptr body_program_shape, - body_computation.ComputeProgramShape(body_version)); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferWhileShape( - *condition_program_shape, *body_program_shape, init->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(condition_version); - request.add_embedded_computation_versions(body_version); - *request.mutable_request()->mutable_while_request() = while_request; - - VLOG(1) << "AddWhileInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << while_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* pred, - LookUpRequest(conditional_request.predicate())); - TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, - LookUpRequest(conditional_request.true_operand())); - TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, - LookUpRequest(conditional_request.false_operand())); - - VersionedComputationHandle::Version true_computation_version = - true_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr true_computation_shape, - true_computation.ComputeProgramShape(true_computation_version)); - - VersionedComputationHandle::Version false_computation_version = - false_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr false_computation_shape, - false_computation.ComputeProgramShape(false_computation_version)); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferConditionalShape( - pred->output_shape(), true_operand->output_shape(), - false_operand->output_shape(), - *true_computation_shape, *false_computation_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(true_computation_version); - request.add_embedded_computation_versions(false_computation_version); - *request.mutable_request()->mutable_conditional_request() = - conditional_request; - - VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << conditional_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBroadcastInstruction( - const BroadcastRequest& broadcast_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(broadcast_request.operand())); - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferBroadcastShape( - operand->output_shape(), - AsInt64Slice(broadcast_request.broadcast_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_broadcast_request() = broadcast_request; - - VLOG(1) << "AddBroadcastInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << broadcast_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReshapeInstruction( - const ReshapeRequest& reshape_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reshape_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferReshapeShape( - operand->output_shape(), AsInt64Slice(reshape_request.dimensions()), - AsInt64Slice(reshape_request.new_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_reshape_request() = reshape_request; - - VLOG(1) << "AddReshapeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reshape_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTransposeInstruction( - const TransposeRequest& transpose_request) { - tensorflow::mutex_lock lock(mutex_); - - // Fetches and validates the operand. - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(transpose_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape inferred_shape, - ShapeInference::InferTransposeShape( - operand->output_shape(), - AsInt64Slice(transpose_request.dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - *request.mutable_request()->mutable_transpose_request() = transpose_request; - - VLOG(1) << "AddTransposeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << transpose_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddSliceInstruction( - const SliceRequest& slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(slice_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferSliceShape( - operand->output_shape(), AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_slice_request() = slice_request; - - VLOG(1) << "AddSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices, - LookUpRequest(dynamic_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferDynamicSliceShape( - operand->output_shape(), start_indices->output_shape(), - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_slice_request() = - dynamic_slice_request; - - VLOG(1) << "AddDynamicSliceInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dynamic_slice_request.ShortDebugString(); - return handle; -} - -StatusOr -UserComputation::AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(dynamic_update_slice_request.operand())); - - TF_ASSIGN_OR_RETURN(const OperationRequest* update, - LookUpRequest(dynamic_update_slice_request.update())); - - TF_ASSIGN_OR_RETURN( - const OperationRequest* start_indices, - LookUpRequest(dynamic_update_slice_request.start_indices())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferDynamicUpdateSliceShape( - operand->output_shape(), update->output_shape(), - start_indices->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_dynamic_update_slice_request() = - dynamic_update_slice_request; - - VLOG(1) << "AddDynamicUpdateSliceInstruction (" - << GetVersionedHandleInternal() << "), data handle " - << handle.handle() << ": " - << dynamic_update_slice_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : concatenate_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape new_shape, - ShapeInference::InferConcatOpShape( - operand_shapes, concatenate_request.dimension())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_concatenate_request() = - concatenate_request; - - VLOG(1) << "AddConcatenateInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << concatenate_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_convert_request() = convert_request; - - VLOG(1) << "AddConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBitcastConvertInstruction( - const ConvertRequest& convert_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(convert_request.operand())); - - TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( - operand->output_shape(), - convert_request.new_element_type())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_bitcast_convert_request() = - convert_request; - - VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convert_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(reduce_precision_request.operand())); - - TF_ASSIGN_OR_RETURN( - Shape new_shape, - ShapeInference::InferReducePrecisionShape( - operand->output_shape(), reduce_precision_request.exponent_bits(), - reduce_precision_request.mantissa_bits())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = new_shape; - *request.mutable_request()->mutable_reduce_precision_request() = - reduce_precision_request; - - VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << reduce_precision_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddConvolveInstruction( - const ConvolveRequest& convolve_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(convolve_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(convolve_request.rhs())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape( - lhs->output_shape(), rhs->output_shape(), - convolve_request.window(), - convolve_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_convolve_request() = convolve_request; - - VLOG(1) << "AddConvolveInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << convolve_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddFftInstruction( - const FftRequest& fft_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(fft_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferFftShape( - operand->output_shape(), fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_fft_request() = fft_request; - - VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << fft_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(cross_replica_sum_request.operand())); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - {&operand->output_shape()})); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_cross_replica_sum_request() = - cross_replica_sum_request; - - VLOG(1) << "AddCrossreplicaSumInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << cross_replica_sum_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddInfeedInstruction( - const InfeedRequest& infeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = infeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Infeed must have a layout"); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_infeed_request() = infeed_request; - - VLOG(1) << "AddInfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << infeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddOutfeedInstruction( - const OutfeedRequest& outfeed_request) { - tensorflow::mutex_lock lock(mutex_); - - const Shape& shape = outfeed_request.shape(); - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Given shape to Outfeed must have a layout"); - } - - // Verify that operand is valid. - TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_outfeed_request() = outfeed_request; - - VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << outfeed_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : call_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - VersionedComputationHandle::Version to_apply_version = - to_apply_computation.version(); - TF_ASSIGN_OR_RETURN( - std::shared_ptr to_apply_program_shape, - to_apply_computation.ComputeProgramShape(to_apply_version)); - TF_ASSIGN_OR_RETURN( - Shape inferred_shape, - ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = inferred_shape; - request.add_embedded_computation_versions(to_apply_version); - *request.mutable_request()->mutable_call_request() = call_request; - - VLOG(1) << "AddCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddCustomCallInstruction( - const CustomCallRequest& custom_call_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : custom_call_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - if (tensorflow::str_util::StartsWith(custom_call_request.call_target_name(), - "$")) { - return InvalidArgument( - "Invalid custom_call_target \"%s\": Call targets that start with '$' " - "are reserved for internal use.", - custom_call_request.call_target_name().c_str()); - } - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = custom_call_request.shape(); - *request.mutable_request()->mutable_custom_call_request() = - custom_call_request; - - VLOG(1) << "AddCustomCallInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << custom_call_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddHostComputeInstruction( - const HostComputeRequest& host_compute_request) { - tensorflow::mutex_lock lock(mutex_); - - for (const ComputationDataHandle& handle : host_compute_request.operands()) { - TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); - } - - ComputationDataHandle handle = CreateComputationDataHandle(); - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = host_compute_request.shape(); - *request.mutable_request()->mutable_host_compute_request() = - host_compute_request; - - VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << host_compute_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddDotInstruction( - const DotRequest& dot_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(dot_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(dot_request.rhs())); - - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( - lhs->output_shape(), rhs->output_shape(), - dot_request.dimension_numbers())); - - const ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_dot_request() = dot_request; - - VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << dot_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddUnaryInstruction( - const UnaryOpRequest& unary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, - LookUpRequest(unary_request.operand())); - TF_ASSIGN_OR_RETURN( - Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(), - operand->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_unary_op_request() = unary_request; - - VLOG(1) << "AddUnaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << unary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddBinaryInstruction( - const BinaryOpRequest& binary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(binary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(binary_request.rhs())); - TF_ASSIGN_OR_RETURN( - Shape shape, - ShapeInference::InferBinaryOpShape( - binary_request.binop(), lhs->output_shape(), rhs->output_shape(), - AsInt64Slice(binary_request.broadcast_dimensions()))); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_binary_op_request() = binary_request; - - VLOG(1) << "AddBinaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << binary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddTernaryInstruction( - const TernaryOpRequest& ternary_request) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, - LookUpRequest(ternary_request.lhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, - LookUpRequest(ternary_request.rhs())); - TF_ASSIGN_OR_RETURN(const OperationRequest* ehs, - LookUpRequest(ternary_request.ehs())); - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferTernaryOpShape( - ternary_request.triop(), lhs->output_shape(), - rhs->output_shape(), ehs->output_shape())); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_ternary_op_request() = ternary_request; - - VLOG(1) << "AddTernaryInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << ternary_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::AddVariadicInstruction( - const VariadicOpRequest& variadic_request) { - tensorflow::mutex_lock lock(mutex_); - - std::vector operand_shapes; - for (const ComputationDataHandle& handle : variadic_request.operands()) { - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - operand_shapes.push_back(&operand->output_shape()); - } - - TF_ASSIGN_OR_RETURN(Shape shape, - ShapeInference::InferVariadicOpShape( - variadic_request.varop(), operand_shapes)); - - ComputationDataHandle handle = CreateComputationDataHandle(); - - OperationRequest& request = - (*session_computation_.mutable_requests())[handle.handle()]; - *request.mutable_output_handle() = handle; - *request.mutable_output_shape() = shape; - *request.mutable_request()->mutable_variadic_op_request() = variadic_request; - - VLOG(1) << "AddVariadicInstruction (" << GetVersionedHandleInternal() - << "), data handle " << handle.handle() << ": " - << variadic_request.ShortDebugString(); - return handle; -} - -StatusOr UserComputation::GetShape(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(handle)); - return operand->output_shape(); -} - -Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpMetadata (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_metadata() = metadata; - return Status::OK(); -} - -Status UserComputation::SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding) { - tensorflow::mutex_lock lock(mutex_); - - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpSharding (%lld)", - handle_value); - } - *session_computation_.mutable_requests() - ->at(handle_value) - .mutable_request() - ->mutable_sharding() = sharding; - return Status::OK(); -} - -Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) { - tensorflow::mutex_lock lock(mutex_); - - if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) { - return InvalidArgument("Invalid handle in SetReturnValue"); - } - - handle_to_return_ = handle; - - VLOG(1) << "SetReturnValue of computation \"" << name() << "\" fixed to " - << GetVersionedHandleInternal(); - - return Status::OK(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandle() const { - tensorflow::mutex_lock lock(mutex_); - return GetVersionedHandleInternal(); -} - -VersionedComputationHandle UserComputation::GetVersionedHandleInternal() const { - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - - if (handle_to_return_.handle() > 0) { - // A specific handle has been requested for the result of the computation. - versioned_handle.version = handle_to_return_.handle(); - } else { - // A version value is simply the most recently assigned - // ComputationDataHandle value, ie the handle value of the root of the - // computation. - versioned_handle.version = next_handle_value_ - 1; - } - - return versioned_handle; -} - -VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const { - tensorflow::mutex_lock lock(mutex_); - - // The version at which an operation was added is simply the handle value of - // the ComputationDataHandle. - VersionedComputationHandle versioned_handle; - versioned_handle.handle = session_computation_.computation_handle(); - versioned_handle.version = operation.handle(); - return versioned_handle; -} - -VersionedComputationHandle::Version UserComputation::version() const { - return GetVersionedHandle().version; -} - -namespace { - -// Returns true if the operation type corresponding to the given opcase can be -// the root of the computation. -bool CanBeRoot(const OpRequest::OpCase& op_case) { - switch (op_case) { - case OpRequest::kTraceRequest: - case OpRequest::kSendRequest: - case OpRequest::kOutfeedRequest: - return false; - default: - return true; - } -} - -// Returns a pointer to the operation with the given data handle value in the -// given SessionComputation. -StatusOr LookUpRequest( - int64 handle_value, const SessionComputation& session_computation) { - if (session_computation.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation.requests().at(handle_value); -} - -// Returns the OperationRequest corresponding to the root (result) of the -// session computation. -StatusOr GetRoot( - VersionedComputationHandle::Version version, - const SessionComputation& session_computation) { - TF_RET_CHECK(version > 0); - // Not all instructions can be roots. Walk backwards from the operation - // indicated by this version until a valid root is found. - const OperationRequest* root_request = nullptr; - while (version > 0) { - TF_ASSIGN_OR_RETURN(root_request, - LookUpRequest(version, session_computation)); - if (CanBeRoot(root_request->request().op_case())) { - break; - } - version--; - } - if (version == 0) { - return InternalError("Computation contains no root operation"); - } - return root_request; -} - -} // namespace - -StatusOr> -UserComputation::ComputeProgramShape( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - if (program_shape_ == nullptr || program_shape_version_ != version) { - // ProgramShape has not been computed yet, or is for different - // version. Compute it now. - TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version)); - - auto program_shape = MakeUnique(); - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - int64 param_no = parameter_request.parameter(); - // Parameters may be out of order so expand ProgramShape parameters - // until it is at least large enough to hold the current parameter - // number. - while (program_shape->parameters_size() <= param_no) { - program_shape->add_parameters(); - program_shape->add_parameter_names(); - } - *program_shape->mutable_parameters(param_no) = request.output_shape(); - *program_shape->mutable_parameter_names(param_no) = - parameter_request.name(); - } - } - - // The root determines the output shape. - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version, session_computation_)); - *program_shape->mutable_result() = root_request->output_shape(); - if (ShapeUtil::IsOpaque(program_shape->result())) { - return Unimplemented("Computation results cannot be opaque"); - } - - program_shape_ = std::move(program_shape); - program_shape_version_ = version; - } - - return program_shape_; -} - -namespace { - -// A visitor which checks whether an operation is pure functional meaning that -// it doesn't depend on any parameter with an index higher then num_parameters. -// The visitor walks the computation starting at a given operation and sets -// is_functional to false iff a parameter or RNG operation is encountered. -void PureFunctionalVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - int64 num_parameters, std::set* visited, - bool* is_functional) { - if (visited->count(handle.handle()) != 0 || !*is_functional) { - return; - } - - const OperationRequest& request = - session_computation.requests().at(handle.handle()); - switch (request.request().op_case()) { - case OpRequest::kRngRequest: - *is_functional = false; - break; - - case OpRequest::kConstantRequest: - break; - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - PureFunctionalVisitor(session_computation, - get_tuple_element_request.operand(), num_parameters, - visited, is_functional); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - PureFunctionalVisitor(session_computation, slice_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.update(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - dynamic_update_slice_request.start_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - PureFunctionalVisitor(session_computation, convolve_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, convolve_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - PureFunctionalVisitor(session_computation, fft_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - // TODO(b/33009255): Implmement constant folding for cross replica sum. - *is_functional = false; - break; - } - - case OpRequest::kInfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kOutfeedRequest: { - *is_functional = false; - break; - } - - case OpRequest::kHostComputeRequest: { - *is_functional = false; - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself, - // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_functional=false in other similar - // cases since we're already relying on IsConstant to return true. - *is_functional = false; - break; - } - - case OpRequest::kCustomCallRequest: { - *is_functional = false; - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - PureFunctionalVisitor(session_computation, dot_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, dot_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kSendRequest: { - *is_functional = false; - break; - } - - case OpRequest::kRecvRequest: { - *is_functional = false; - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - PureFunctionalVisitor(session_computation, reduce_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, reduce_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - PureFunctionalVisitor(session_computation, - reduce_window_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - reduce_window_request.init_value(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the to_apply computation itself. - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.source(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - select_and_scatter_request.init_value(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the select and scatter - // computations themselves. - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - PureFunctionalVisitor(session_computation, broadcast_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - PureFunctionalVisitor(session_computation, reshape_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - PureFunctionalVisitor(session_computation, reverse_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - PureFunctionalVisitor(session_computation, pad_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, pad_request.padding_value(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - if (parameter_request.parameter() >= num_parameters) { - *is_functional = false; - } - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - PureFunctionalVisitor(session_computation, convert_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - PureFunctionalVisitor(session_computation, while_request.init(), - num_parameters, visited, is_functional); - // TODO(b/32495713): We aren't checking the condition and body - // computations themselves. - *is_functional = false; - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - PureFunctionalVisitor(session_computation, - conditional_request.predicate(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.true_operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - conditional_request.false_operand(), num_parameters, - visited, is_functional); - // TODO(b/32495713): We aren't checking the true and false computations - // themselves. - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - PureFunctionalVisitor(session_computation, transpose_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - PureFunctionalVisitor(session_computation, handle, num_parameters, - visited, is_functional); - } - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - PureFunctionalVisitor(session_computation, unary_op_request.operand(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_training_request.offset(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.operand(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.scale(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.offset(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.mean(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_inference_request.variance(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.operand(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.scale(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.variance(), num_parameters, - visited, is_functional); - PureFunctionalVisitor(session_computation, - batch_norm_grad_request.grad_output(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - PureFunctionalVisitor(session_computation, binary_op_request.lhs(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, binary_op_request.rhs(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::kGatherRequest: { - PureFunctionalVisitor(session_computation, - request.request().gather_request().input(), - num_parameters, visited, is_functional); - PureFunctionalVisitor(session_computation, - request.request().gather_request().gather_indices(), - num_parameters, visited, is_functional); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - if (!*is_functional) { - VLOG(1) << "Non-functional: " << request.request().DebugString(); - } - visited->insert(handle.handle()); -} - -} // namespace - -StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, - int64 num_parameters) { - tensorflow::mutex_lock lock(mutex_); - - // Verify that the handle is valid. - auto operation_status = LookUpRequest(handle); - if (!operation_status.ok()) { - return operation_status.status(); - } - - bool is_constant = true; - std::set visited; - PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, - &is_constant); - - return is_constant; -} - -std::vector -UserComputation::GetEmbeddedComputations( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(1) - << "GetEmbeddedComputations(" << name() << " " - << VersionedComputationHandle{session_computation_.computation_handle(), - version} - << ")"; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - std::vector computations; - std::vector sorted_handles; - for (const auto& handle_request : session_computation_.requests()) { - sorted_handles.push_back(handle_request.first); - } - std::sort(sorted_handles.begin(), sorted_handles.end()); - for (int64 handle : sorted_handles) { - const auto& handle_request = session_computation_.requests().find(handle); - CHECK(handle_request != session_computation_.requests().end()); - int64 handle_value = handle_request->first; - if (handle_value <= version) { - const OperationRequest& request = handle_request->second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const CallRequest& call_request = request.request().call_request(); - const VersionedComputationHandle versioned_handle = { - call_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kMapRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const MapRequest& map_request = request.request().map_request(); - const VersionedComputationHandle versioned_handle = { - map_request.to_apply(), request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceRequest& reduce_request = - request.request().reduce_request(); - const VersionedComputationHandle versioned_handle = { - reduce_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kReduceWindowRequest: { - CHECK_EQ(1, request.embedded_computation_versions_size()); - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - const VersionedComputationHandle versioned_handle = { - reduce_window_request.to_apply(), - request.embedded_computation_versions(0)}; - computations.push_back(versioned_handle); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - const VersionedComputationHandle select_versioned_handle = { - select_and_scatter_request.select(), - request.embedded_computation_versions(0)}; - computations.push_back(select_versioned_handle); - const VersionedComputationHandle scatter_versioned_handle = { - select_and_scatter_request.scatter(), - request.embedded_computation_versions(1)}; - computations.push_back(scatter_versioned_handle); - break; - } - - case OpRequest::kWhileRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const WhileRequest& while_request = request.request().while_request(); - const VersionedComputationHandle condition_versioned_handle = { - while_request.condition(), - request.embedded_computation_versions(0)}; - computations.push_back(condition_versioned_handle); - const VersionedComputationHandle body_versioned_handle = { - while_request.body(), request.embedded_computation_versions(1)}; - computations.push_back(body_versioned_handle); - break; - } - - case OpRequest::kConditionalRequest: { - CHECK_EQ(2, request.embedded_computation_versions_size()); - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - const VersionedComputationHandle true_computation_versioned_handle = { - conditional_request.true_computation(), - request.embedded_computation_versions(0)}; - computations.push_back(true_computation_versioned_handle); - const VersionedComputationHandle false_computation_versioned_handle = - {conditional_request.false_computation(), - request.embedded_computation_versions(1)}; - computations.push_back(false_computation_versioned_handle); - break; - } - - default: - // No embedded computation. - break; - } - } - } - VLOG(2) << "Embedded computations: " - << tensorflow::str_util::Join( - computations, ", ", - [](string* out, const VersionedComputationHandle& h) { - out->append(h.ToString()); - }); - return computations; -} - -StatusOr -UserComputation::LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const { - tensorflow::mutex_lock lock(mutex_); - return LookUpRequest(handle); -} - -tensorflow::gtl::optional UserComputation::ParameterMetadata( - int parameter_number) const { - tensorflow::mutex_lock lock(mutex_); - auto it = parameters_.find(parameter_number); - if (it == parameters_.end()) { - return tensorflow::gtl::nullopt; - } - OperationRequest* op = it->second; - return &op->request().metadata(); -} - -Status UserComputation::RemapEmbeddedComputations( - const std::map& old_to_new) { - auto update = [&old_to_new](ComputationHandle* to_update) -> Status { - int64 old = to_update->handle(); - auto it = old_to_new.find(old); - if (it == old_to_new.end()) { - string mapping = tensorflow::str_util::Join( - old_to_new, ", ", - [](string* out, std::pair element) { - tensorflow::strings::Appendf(out, "%lld:%lld", element.first, - element.second.handle()); - }); - return NotFound( - "could not find referenced (old) computation handle in mapping: " - "%lld; mapping: {%s}", - old, mapping.c_str()); - } - VLOG(2) << "remapping " << old << " to " << it->second.handle(); - *to_update = it->second; - return Status::OK(); - }; - TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle())); - for (auto& handle_request : *session_computation_.mutable_requests()) { - OperationRequest& request = handle_request.second; - switch (request.request().op_case()) { - case OpRequest::kCallRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - CallRequest* call_request = - request.mutable_request()->mutable_call_request(); - TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply())); - break; - } - case OpRequest::kMapRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - MapRequest* map_request = - request.mutable_request()->mutable_map_request(); - TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceRequest* reduce_request = - request.mutable_request()->mutable_reduce_request(); - TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply())); - break; - } - case OpRequest::kReduceWindowRequest: { - TF_RET_CHECK(1 == request.embedded_computation_versions_size()); - ReduceWindowRequest* reduce_window_request = - request.mutable_request()->mutable_reduce_window_request(); - TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply())); - break; - } - case OpRequest::kSelectAndScatterRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - SelectAndScatterRequest* select_and_scatter_request = - request.mutable_request()->mutable_select_and_scatter_request(); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_select())); - TF_RETURN_IF_ERROR( - update(select_and_scatter_request->mutable_scatter())); - break; - } - case OpRequest::kWhileRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - WhileRequest* while_request = - request.mutable_request()->mutable_while_request(); - TF_RETURN_IF_ERROR(update(while_request->mutable_condition())); - TF_RETURN_IF_ERROR(update(while_request->mutable_body())); - break; - } - case OpRequest::kConditionalRequest: { - TF_RET_CHECK(2 == request.embedded_computation_versions_size()); - ConditionalRequest* conditional_request = - request.mutable_request()->mutable_conditional_request(); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_true_computation())); - TF_RETURN_IF_ERROR( - update(conditional_request->mutable_false_computation())); - break; - } - default: - // No embedded computation. - TF_RET_CHECK(0 == request.embedded_computation_versions_size()); - break; - } - } - return Status::OK(); -} - -SessionComputation UserComputation::CloneSessionComputation( - VersionedComputationHandle::Version version) const { - tensorflow::mutex_lock lock(mutex_); - SessionComputation result = session_computation_; - // Erase all the requests that exceed the version specified. - // There's no lower_bound method on tensorflow::protobuf::Map so we iterate - // all the elements. - auto it = result.mutable_requests()->begin(); - while (it != result.mutable_requests()->end()) { - if (it->first > version) { - it = result.mutable_requests()->erase(it); - } else { - ++it; - } - } - return result; -} - -StatusOr UserComputation::LookUpRequest( - const ComputationDataHandle& handle) const { - int64 handle_value = handle.handle(); - if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("no ComputationDataHandle value %lld", handle_value); - } - return &session_computation_.requests().at(handle_value); -} - -Status UserComputation::CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const { - TF_RET_CHECK(version > 0 && version < next_handle_value_); - - // Determine number of parameter inputs at the given version. - std::map parameter_requests; - for (int64 request_num = 1; request_num <= version; ++request_num) { - const OperationRequest& request = - session_computation_.requests().at(request_num); - - if (request.request().op_case() == OpRequest::kParameterRequest) { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - // Duplicate parameters should be checked when parameter requests are - // added. - TF_RET_CHECK(0 == - parameter_requests.count(parameter_request.parameter())); - parameter_requests[parameter_request.parameter()] = ¶meter_request; - } - } - - for (int64 i = 0; i < parameter_requests.size(); ++i) { - auto it = parameter_requests.find(i); - if (it == parameter_requests.end()) { - return FailedPrecondition( - "computation %s does not have all its parameters populated " - "sequentially, missing parameter %lld", - name_.c_str(), i); - } - } - - return Status::OK(); -} - -namespace { - -// Helper class which builds an HLO computation from a SessionComputation. To -// construct the HLO computation, the SessionComputation graph is walked in -// DFS order lowering each OperationRequest to an HLO instruction. -class ComputationLowerer { - public: - static StatusOr> Lower( - const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) { - ComputationLowerer lowerer(computation_name, session_computation, version, - std::move(hlo_resolver), debug_options, - include_unreachable_instructions); - return lowerer.Lower(); - } - - private: - ComputationLowerer(const string& computation_name, - const SessionComputation& session_computation, - VersionedComputationHandle::Version version, - UserComputation::HloComputationResolver hlo_resolver, - const DebugOptions& debug_options, - bool include_unreachable_instructions) - : hlo_builder_(computation_name), - session_computation_(session_computation), - version_(version), - hlo_resolver_(std::move(hlo_resolver)), - debug_options_(debug_options), - include_unreachable_instructions_(include_unreachable_instructions) {} - - // Build an HLO computation from the SessionComputation at the given - // version. - StatusOr> Lower(); - - private: - // Traverses the computation 'root' using a DFS, calling 'visit' in postorder. - void TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit); - - // DFS visitor of the UserComputation operations which lowers the operations - // to HLO instructions. - void Visit(const ComputationDataHandle& handle, - std::unordered_map* instructions); - - // Resolves a ComputationHandle and Version to a previously lowered - // HloComputation using the hlo_resolver_ function. - HloComputation* ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version); - - // This function takes an input value which is being implicitly broadcast into - // an output shape and figures out the right kBroadcast instruction(s) - // necessary to replicate the implicit broadcast semantics explicitly. - HloInstruction* ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape); - - HloComputation::Builder hlo_builder_; - const SessionComputation& session_computation_; - const VersionedComputationHandle::Version version_; - const UserComputation::HloComputationResolver hlo_resolver_; - const DebugOptions& debug_options_; - const bool include_unreachable_instructions_; -}; - -// Calls 'apply' on each operand of 'request'. -static void ForEachOperand( - const OperationRequest& request, - const std::function& apply) { - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - for (const ComputationDataHandle& param : rng_request.parameter()) { - apply(param); - } - break; - } - - case OpRequest::kConstantRequest: - break; - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - apply(get_tuple_element_request.operand()); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - apply(slice_request.operand()); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - apply(dynamic_slice_request.operand()); - apply(dynamic_slice_request.start_indices()); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - apply(dynamic_update_slice_request.operand()); - apply(dynamic_update_slice_request.update()); - apply(dynamic_update_slice_request.start_indices()); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - apply(convolve_request.lhs()); - apply(convolve_request.rhs()); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - apply(fft_request.operand()); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - - apply(batch_norm_training_request.operand()); - apply(batch_norm_training_request.scale()); - apply(batch_norm_training_request.offset()); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - - apply(batch_norm_inference_request.operand()); - apply(batch_norm_inference_request.scale()); - apply(batch_norm_inference_request.offset()); - apply(batch_norm_inference_request.mean()); - apply(batch_norm_inference_request.variance()); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - apply(batch_norm_grad_request.operand()); - apply(batch_norm_grad_request.scale()); - apply(batch_norm_grad_request.mean()); - apply(batch_norm_grad_request.variance()); - apply(batch_norm_grad_request.grad_output()); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - apply(cross_replica_sum_request.operand()); - break; - } - - case OpRequest::kInfeedRequest: - break; - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - apply(outfeed_request.operand()); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - for (const ComputationDataHandle& handle : map_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - apply(reduce_request.operand()); - apply(reduce_request.init_value()); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - apply(reduce_window_request.operand()); - apply(reduce_window_request.init_value()); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - apply(select_and_scatter_request.operand()); - apply(select_and_scatter_request.source()); - apply(select_and_scatter_request.init_value()); - - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - apply(broadcast_request.operand()); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - apply(reshape_request.operand()); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - apply(transpose_request.operand()); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - apply(reverse_request.operand()); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - apply(pad_request.operand()); - apply(pad_request.padding_value()); - break; - } - - case OpRequest::kRecvRequest: - case OpRequest::kParameterRequest: - break; - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - apply(convert_request.operand()); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - apply(while_request.init()); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - apply(conditional_request.predicate()); - apply(conditional_request.true_operand()); - apply(conditional_request.false_operand()); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - apply(ternary_op_request.lhs()); - apply(ternary_op_request.rhs()); - apply(ternary_op_request.ehs()); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - for (const ComputationDataHandle& handle : call_request.operands()) { - apply(handle); - } - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - for (const ComputationDataHandle& operand : cc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& hc_request = - request.request().host_compute_request(); - for (const ComputationDataHandle& operand : hc_request.operands()) { - apply(operand); - } - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - apply(dot_request.rhs()); - apply(dot_request.lhs()); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - apply(unary_op_request.operand()); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - apply(binary_op_request.rhs()); - apply(binary_op_request.lhs()); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - apply(reduce_precision_request.operand()); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - apply(trace_request.operand()); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - apply(send_request.operand()); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - apply(gather_request.input()); - apply(gather_request.gather_indices()); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } -} - -void ComputationLowerer::TraversePostorder( - const ComputationDataHandle& root, - std::unordered_map* visited, - const std::function& visit) { - // Stack containing {handle, enter} pairs. The 'enter' value describes whether - // we are entering or leaving 'handle'. - std::stack> work; - work.push({root, true}); - while (!work.empty()) { - ComputationDataHandle handle; - bool enter; - std::tie(handle, enter) = work.top(); - work.pop(); - - if (enter) { - // We are entering 'handle'. The first time we enter 'handle', we add it - // to 'visited' with a nullptr value. If 'handle' is already in 'visited', - // we do not visit it again. This algorithm only uses the presence of - // a handle in 'visited', but we use a map so we can use the same data - // structure to store the HloInstruction outputs. - if (visited->emplace(handle.handle(), nullptr).second) { - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - // Push the corresponding 'leave' action onto the stack, followed by - // the operands. - work.push({handle, false}); - ForEachOperand(request, [&work](const ComputationDataHandle& child) { - work.push({child, true}); - }); - } - } else { - // We are leaving 'handle'. We have visited the operands of 'handle', and - // now can visit the 'handle' itself. - visit(handle); - } - } -} - -StatusOr> ComputationLowerer::Lower() { - // Map from ComputationDataHandle to HLO instruction. Serves as a record of - // which operations have been visited as well as a cache for looking up - // ComputationDataHandles as HloInstructions. - std::unordered_map instructions; - - TF_ASSIGN_OR_RETURN(const OperationRequest* root_request, - GetRoot(version_, session_computation_)); - - auto visit = [&](const ComputationDataHandle& handle) { - Visit(handle, &instructions); - }; - TraversePostorder(root_request->output_handle(), &instructions, visit); - HloInstruction* hlo_root = - instructions.at(root_request->output_handle().handle()); - - if (include_unreachable_instructions_) { - // Iterate through all computation data handles, and visit any unvisited - // operations. - for (int64 request_num = 1; request_num <= version_; ++request_num) { - TF_ASSIGN_OR_RETURN(const OperationRequest* request, - LookUpRequest(request_num, session_computation_)); - TraversePostorder(request->output_handle(), &instructions, visit); - } - } - - return hlo_builder_.Build(hlo_root); -} - -HloComputation* ComputationLowerer::ResolveComputation( - const ComputationHandle& handle, - VersionedComputationHandle::Version version) { - const VersionedComputationHandle checked_handle = {handle, version}; - return hlo_resolver_(checked_handle); -} - -HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( - HloInstruction* operand, const Shape& output_shape) { - auto fadd = [this](std::unique_ptr x) { - return hlo_builder_.AddInstruction(std::move(x)); - }; - return fadd( - HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); -} - -void ComputationLowerer::Visit( - const ComputationDataHandle& handle, - std::unordered_map* instructions) { - CHECK_LE(handle.handle(), version_); - CHECK(instructions->at(handle.handle()) == nullptr); - const OperationRequest& request = - session_computation_.requests().at(handle.handle()); - auto add_instruction = [&](std::unique_ptr instruction) { - HloInstruction* hlo_instruction = - hlo_builder_.AddInstruction(std::move(instruction)); - hlo_instruction->set_metadata(request.request().metadata()); - if (request.request().has_sharding()) { - OpSharding op_sharding = request.request().sharding(); - hlo_instruction->set_sharding( - HloSharding::FromProto(op_sharding).ValueOrDie()); - } - return hlo_instruction; - }; - auto lookup_instruction = [&](const ComputationDataHandle& handle) { - return instructions->at(handle.handle()); - }; - HloInstruction* hlo_instruction; - switch (request.request().op_case()) { - case OpRequest::kRngRequest: { - const RngRequest& rng_request = request.request().rng_request(); - std::vector parameters; - for (const ComputationDataHandle& param : rng_request.parameter()) { - parameters.push_back(lookup_instruction(param)); - } - hlo_instruction = add_instruction(HloInstruction::CreateRng( - request.output_shape(), rng_request.distribution(), parameters)); - break; - } - - case OpRequest::kConstantRequest: { - const ConstantRequest& constant_request = - request.request().constant_request(); - hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal::CreateFromProto(constant_request.literal()) - .ConsumeValueOrDie())); - break; - } - - case OpRequest::kGetTupleElementRequest: { - const GetTupleElementRequest& get_tuple_element_request = - request.request().get_tuple_element_request(); - HloInstruction* operand = - lookup_instruction(get_tuple_element_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateGetTupleElement( - request.output_shape(), operand, get_tuple_element_request.index())); - break; - } - - case OpRequest::kSliceRequest: { - const SliceRequest& slice_request = request.request().slice_request(); - HloInstruction* operand = lookup_instruction(slice_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSlice( - request.output_shape(), operand, - AsInt64Slice(slice_request.start_indices()), - AsInt64Slice(slice_request.limit_indices()), - AsInt64Slice(slice_request.strides()))); - break; - } - - case OpRequest::kDynamicSliceRequest: { - const DynamicSliceRequest& dynamic_slice_request = - request.request().dynamic_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_slice_request.operand()); - HloInstruction* start_indices = - lookup_instruction(dynamic_slice_request.start_indices()); - - hlo_instruction = add_instruction(HloInstruction::CreateDynamicSlice( - request.output_shape(), operand, start_indices, - AsInt64Slice(dynamic_slice_request.slice_sizes()))); - break; - } - - case OpRequest::kDynamicUpdateSliceRequest: { - const DynamicUpdateSliceRequest& dynamic_update_slice_request = - request.request().dynamic_update_slice_request(); - HloInstruction* operand = - lookup_instruction(dynamic_update_slice_request.operand()); - HloInstruction* update = - lookup_instruction(dynamic_update_slice_request.update()); - HloInstruction* start_indices = - lookup_instruction(dynamic_update_slice_request.start_indices()); - hlo_instruction = - add_instruction(HloInstruction::CreateDynamicUpdateSlice( - request.output_shape(), operand, update, start_indices)); - break; - } - - case OpRequest::kConcatenateRequest: { - const ConcatenateRequest& concatenate_request = - request.request().concatenate_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - concatenate_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - hlo_instruction = add_instruction(HloInstruction::CreateConcatenate( - request.output_shape(), operands, concatenate_request.dimension())); - break; - } - - case OpRequest::kConvolveRequest: { - const ConvolveRequest& convolve_request = - request.request().convolve_request(); - HloInstruction* lhs = lookup_instruction(convolve_request.lhs()); - HloInstruction* rhs = lookup_instruction(convolve_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateConvolve( - request.output_shape(), lhs, rhs, convolve_request.window(), - convolve_request.dimension_numbers())); - break; - } - - case OpRequest::kFftRequest: { - const FftRequest& fft_request = request.request().fft_request(); - HloInstruction* operand = lookup_instruction(fft_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateFft( - request.output_shape(), operand, fft_request.fft_type(), - AsInt64Slice(fft_request.fft_length()))); - break; - } - - case OpRequest::kDotRequest: { - const DotRequest& dot_request = request.request().dot_request(); - HloInstruction* lhs = lookup_instruction(dot_request.lhs()); - HloInstruction* rhs = lookup_instruction(dot_request.rhs()); - hlo_instruction = add_instruction(HloInstruction::CreateDot( - request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); - break; - } - - case OpRequest::kCrossReplicaSumRequest: { - const CrossReplicaSumRequest& cross_replica_sum_request = - request.request().cross_replica_sum_request(); - HloInstruction* operand = - lookup_instruction(cross_replica_sum_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), {operand})); - break; - } - - case OpRequest::kInfeedRequest: { - const InfeedRequest& infeed_request = request.request().infeed_request(); - hlo_instruction = add_instruction(HloInstruction::CreateInfeed( - request.output_shape(), infeed_request.config())); - break; - } - - case OpRequest::kOutfeedRequest: { - const OutfeedRequest& outfeed_request = - request.request().outfeed_request(); - HloInstruction* operand = lookup_instruction(outfeed_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateOutfeed( - outfeed_request.shape(), operand, outfeed_request.outfeed_config())); - break; - } - - case OpRequest::kMapRequest: { - const MapRequest& map_request = request.request().map_request(); - std::vector operands; - for (const ComputationDataHandle& handle : map_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version map_version = - request.embedded_computation_versions(0); - HloComputation* map_computation = - ResolveComputation(map_request.to_apply(), map_version); - hlo_instruction = add_instruction(HloInstruction::CreateMap( - request.output_shape(), operands, map_computation)); - break; - } - - case OpRequest::kReduceRequest: { - const ReduceRequest& reduce_request = request.request().reduce_request(); - HloInstruction* operand = lookup_instruction(reduce_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_version = - request.embedded_computation_versions(0); - HloComputation* reduce_computation = - ResolveComputation(reduce_request.to_apply(), reduce_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduce( - request.output_shape(), operand, init_value, - AsInt64Slice(reduce_request.dimensions()), reduce_computation)); - break; - } - - case OpRequest::kReduceWindowRequest: { - const ReduceWindowRequest& reduce_window_request = - request.request().reduce_window_request(); - HloInstruction* operand = - lookup_instruction(reduce_window_request.operand()); - HloInstruction* init_value = - lookup_instruction(reduce_window_request.init_value()); - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version reduce_window_version = - request.embedded_computation_versions(0); - HloComputation* reduce_window_computation = ResolveComputation( - reduce_window_request.to_apply(), reduce_window_version); - hlo_instruction = add_instruction(HloInstruction::CreateReduceWindow( - request.output_shape(), operand, init_value, - reduce_window_request.window(), reduce_window_computation)); - break; - } - - case OpRequest::kSelectAndScatterRequest: { - const SelectAndScatterRequest& select_and_scatter_request = - request.request().select_and_scatter_request(); - HloInstruction* operand = - lookup_instruction(select_and_scatter_request.operand()); - HloInstruction* source = - lookup_instruction(select_and_scatter_request.source()); - HloInstruction* init_value = - lookup_instruction(select_and_scatter_request.init_value()); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version select_version = - request.embedded_computation_versions(0); - VersionedComputationHandle::Version scatter_version = - request.embedded_computation_versions(1); - HloComputation* select_computation = ResolveComputation( - select_and_scatter_request.select(), select_version); - HloComputation* scatter_computation = ResolveComputation( - select_and_scatter_request.scatter(), scatter_version); - hlo_instruction = add_instruction(HloInstruction::CreateSelectAndScatter( - request.output_shape(), operand, select_computation, - select_and_scatter_request.window(), source, init_value, - scatter_computation)); - break; - } - - case OpRequest::kBatchNormTrainingRequest: { - const BatchNormTrainingRequest& batch_norm_training_request = - request.request().batch_norm_training_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_training_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_training_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_training_request.offset()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( - request.output_shape(), operand, scale, offset, - batch_norm_training_request.epsilon(), - batch_norm_training_request.feature_index())); - break; - } - - case OpRequest::kBatchNormInferenceRequest: { - const BatchNormInferenceRequest& batch_norm_inference_request = - request.request().batch_norm_inference_request(); - HloInstruction* operand = - lookup_instruction(batch_norm_inference_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_inference_request.scale()); - HloInstruction* offset = - lookup_instruction(batch_norm_inference_request.offset()); - HloInstruction* mean = - lookup_instruction(batch_norm_inference_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_inference_request.variance()); - - hlo_instruction = - add_instruction(HloInstruction::CreateBatchNormInference( - request.output_shape(), operand, scale, offset, mean, variance, - batch_norm_inference_request.epsilon(), - batch_norm_inference_request.feature_index())); - break; - } - - case OpRequest::kBatchNormGradRequest: { - const BatchNormGradRequest& batch_norm_grad_request = - request.request().batch_norm_grad_request(); - - HloInstruction* operand = - lookup_instruction(batch_norm_grad_request.operand()); - HloInstruction* scale = - lookup_instruction(batch_norm_grad_request.scale()); - HloInstruction* mean = lookup_instruction(batch_norm_grad_request.mean()); - HloInstruction* variance = - lookup_instruction(batch_norm_grad_request.variance()); - HloInstruction* grad_output = - lookup_instruction(batch_norm_grad_request.grad_output()); - - hlo_instruction = add_instruction(HloInstruction::CreateBatchNormGrad( - request.output_shape(), operand, scale, mean, variance, grad_output, - batch_norm_grad_request.epsilon(), - batch_norm_grad_request.feature_index())); - break; - } - - case OpRequest::kBroadcastRequest: { - const BroadcastRequest& broadcast_request = - request.request().broadcast_request(); - HloInstruction* operand = lookup_instruction(broadcast_request.operand()); - std::vector broadcast_dimensions; - // The client-level broadcast instruction just appends dimensions on the - // left (adds lowest numbered dimensions). The HLO broadcast op is more - // flexible and can add new dimensions anywhere. The broadcast_dimensions - // maps operand dimensions to dimensions in the broadcast output, so - // to append dimensions on the left the broadcast_dimensions should just - // be the n highest dimension numbers of the output shape where n is - // the number of input dimensions. - broadcast_dimensions.reserve(ShapeUtil::Rank(operand->shape())); - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) { - broadcast_dimensions.push_back(i + - ShapeUtil::Rank(request.output_shape()) - - ShapeUtil::Rank(operand->shape())); - } - hlo_instruction = add_instruction(HloInstruction::CreateBroadcast( - request.output_shape(), operand, broadcast_dimensions)); - break; - } - - case OpRequest::kReshapeRequest: { - const ReshapeRequest& reshape_request = - request.request().reshape_request(); - HloInstruction* operand = lookup_instruction(reshape_request.operand()); - HloInstruction* transposed; - if (IsIdentityPermutation(AsInt64Slice(reshape_request.dimensions()))) { - transposed = operand; - } else { - transposed = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(reshape_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(reshape_request.dimensions()))); - } - hlo_instruction = add_instruction( - HloInstruction::CreateReshape(request.output_shape(), transposed)); - break; - } - - case OpRequest::kTransposeRequest: { - const TransposeRequest& transpose_request = - request.request().transpose_request(); - HloInstruction* operand = lookup_instruction(transpose_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateTranspose( - ShapeUtil::PermuteDimensions( - InversePermutation(AsInt64Slice(transpose_request.dimensions())), - operand->shape()), - operand, AsInt64Slice(transpose_request.dimensions()))); - break; - } - - case OpRequest::kReverseRequest: { - const ReverseRequest& reverse_request = - request.request().reverse_request(); - HloInstruction* operand = lookup_instruction(reverse_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateReverse( - request.output_shape(), operand, - AsInt64Slice(reverse_request.dimensions()))); - break; - } - - case OpRequest::kPadRequest: { - const PadRequest& pad_request = request.request().pad_request(); - HloInstruction* operand = lookup_instruction(pad_request.operand()); - HloInstruction* padding_value = - lookup_instruction(pad_request.padding_value()); - hlo_instruction = add_instruction(HloInstruction::CreatePad( - request.output_shape(), operand, padding_value, - pad_request.padding_config())); - break; - } - - case OpRequest::kRecvRequest: { - const RecvRequest& recv_request = request.request().recv_request(); - HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( - request.output_shape(), recv_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); - break; - } - - case OpRequest::kParameterRequest: { - const ParameterRequest& parameter_request = - request.request().parameter_request(); - hlo_instruction = add_instruction(HloInstruction::CreateParameter( - parameter_request.parameter(), request.output_shape(), - parameter_request.name())); - break; - } - - case OpRequest::kConvertRequest: { - const ConvertRequest& convert_request = - request.request().convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateConvert(request.output_shape(), operand)); - break; - } - - case OpRequest::kBitcastConvertRequest: { - const ConvertRequest& convert_request = - request.request().bitcast_convert_request(); - HloInstruction* operand = lookup_instruction(convert_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( - request.output_shape(), operand)); - break; - } - - case OpRequest::kWhileRequest: { - const WhileRequest& while_request = request.request().while_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version condition_version = - request.embedded_computation_versions(0); - HloComputation* condition = - ResolveComputation(while_request.condition(), condition_version); - VersionedComputationHandle::Version body_version = - request.embedded_computation_versions(1); - HloComputation* body = - ResolveComputation(while_request.body(), body_version); - HloInstruction* init = lookup_instruction(while_request.init()); - hlo_instruction = add_instruction(HloInstruction::CreateWhile( - request.output_shape(), condition, body, init)); - break; - } - - case OpRequest::kConditionalRequest: { - const ConditionalRequest& conditional_request = - request.request().conditional_request(); - CHECK_EQ(2, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version true_computation_version = - request.embedded_computation_versions(0); - HloComputation* true_computation = ResolveComputation( - conditional_request.true_computation(), true_computation_version); - VersionedComputationHandle::Version false_computation_version = - request.embedded_computation_versions(1); - HloComputation* false_computation = ResolveComputation( - conditional_request.false_computation(), false_computation_version); - HloInstruction* predicate = - lookup_instruction(conditional_request.predicate()); - HloInstruction* true_operand = - lookup_instruction(conditional_request.true_operand()); - HloInstruction* false_operand = - lookup_instruction(conditional_request.false_operand()); - hlo_instruction = add_instruction(HloInstruction::CreateConditional( - request.output_shape(), predicate, true_operand, true_computation, - false_operand, false_computation)); - break; - } - - case OpRequest::kTernaryOpRequest: { - const TernaryOpRequest& ternary_op_request = - request.request().ternary_op_request(); - HloInstruction* lhs = lookup_instruction(ternary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); - HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); - auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - !ShapeUtil::IsTuple(request.output_shape())) { - if (!ShapeUtil::IsTuple(lhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(rhs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - - if (!ShapeUtil::IsTuple(ehs->shape()) && - !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { - ehs = - ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); - } - } - - hlo_instruction = add_instruction(HloInstruction::CreateTernary( - request.output_shape(), hlo_opcode, lhs, rhs, ehs)); - break; - } - - case OpRequest::kVariadicOpRequest: { - const VariadicOpRequest& variadic_op_request = - request.request().variadic_op_request(); - std::vector operands; - for (const ComputationDataHandle& handle : - variadic_op_request.operands()) { - HloInstruction* operand = lookup_instruction(handle); - operands.push_back(operand); - } - auto hlo_opcode = - VariadicOperationToHloOpcode(variadic_op_request.varop()); - hlo_instruction = add_instruction(HloInstruction::CreateVariadic( - request.output_shape(), hlo_opcode, operands)); - break; - } - - case OpRequest::kCallRequest: { - const CallRequest& call_request = request.request().call_request(); - std::vector operands; - for (const ComputationDataHandle& handle : call_request.operands()) { - operands.push_back(lookup_instruction(handle)); - } - CHECK_EQ(1, request.embedded_computation_versions_size()); - VersionedComputationHandle::Version call_version = - request.embedded_computation_versions(0); - HloComputation* call_computation = - ResolveComputation(call_request.to_apply(), call_version); - hlo_instruction = add_instruction(HloInstruction::CreateCall( - request.output_shape(), operands, call_computation)); - break; - } - - case OpRequest::kCustomCallRequest: { - const CustomCallRequest& cc_request = - request.request().custom_call_request(); - std::vector operands; - for (const ComputationDataHandle& operand : cc_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - hlo_instruction = add_instruction(HloInstruction::CreateCustomCall( - cc_request.shape(), operands, cc_request.call_target_name())); - break; - } - - case OpRequest::kHostComputeRequest: { - const HostComputeRequest& host_compute_request = - request.request().host_compute_request(); - std::vector operands; - for (const ComputationDataHandle& operand : - host_compute_request.operands()) { - operands.push_back(lookup_instruction(operand)); - } - auto output_shape = host_compute_request.shape(); - auto channel_name = host_compute_request.channel_name(); - auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); - hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( - output_shape, operands, channel_name, cost_estimate_ns)); - break; - } - - case OpRequest::kUnaryOpRequest: { - const UnaryOpRequest& unary_op_request = - request.request().unary_op_request(); - HloInstruction* operand = lookup_instruction(unary_op_request.operand()); - auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop()); - hlo_instruction = add_instruction(HloInstruction::CreateUnary( - request.output_shape(), hlo_opcode, operand)); - break; - } - - case OpRequest::kBinaryOpRequest: { - const BinaryOpRequest& binary_op_request = - request.request().binary_op_request(); - HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); - HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); - auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); - if (binary_op_request.broadcast_dimensions_size() > 0 && - ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) { - // Emit a broadcast instruction to perform the "broadcast in dimension" - // operation. - HloInstruction* operand_to_broadcast = - ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs - : rhs; - CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()), - binary_op_request.broadcast_dimensions().size()); - - // Construct the bounds of the shape of the kBroadcast instruction - // responsible for the in-dimension broadcast. - std::vector output_dimensions; - for (int64 size : request.output_shape().dimensions()) { - output_dimensions.push_back(size); - } - for (int64 operand_dim = 0; - operand_dim < ShapeUtil::Rank(operand_to_broadcast->shape()); - ++operand_dim) { - int64 output_dim = - binary_op_request.broadcast_dimensions()[operand_dim]; - output_dimensions[output_dim] = - operand_to_broadcast->shape().dimensions(operand_dim); - } - - Shape broadcast_shape = ShapeUtil::MakeShape( - operand_to_broadcast->shape().element_type(), output_dimensions); - - // The broadcast semantics of a client-level binary op broadcast is - // identical to the HLO broadcast semantics so the broadcast_dimensions - // field can just be passed to the instruction builder. - HloInstruction* broadcasted_operand = - add_instruction(HloInstruction::CreateBroadcast( - broadcast_shape, operand_to_broadcast, - AsInt64Slice(binary_op_request.broadcast_dimensions()))); - - lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; - rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; - } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { - if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { - // lhs side is being implicitly broadcast. Change to explicit. - lhs = - ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); - } - - if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { - rhs = - ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); - } - } - hlo_instruction = add_instruction(HloInstruction::CreateBinary( - request.output_shape(), hlo_opcode, lhs, rhs)); - break; - } - - case OpRequest::kReducePrecisionRequest: { - const ReducePrecisionRequest& reduce_precision_request = - request.request().reduce_precision_request(); - HloInstruction* operand = - lookup_instruction(reduce_precision_request.operand()); - auto exponent_bits = reduce_precision_request.exponent_bits(); - auto mantissa_bits = reduce_precision_request.mantissa_bits(); - hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( - request.output_shape(), operand, exponent_bits, mantissa_bits)); - break; - } - - case OpRequest::kTraceRequest: { - const TraceRequest& trace_request = request.request().trace_request(); - HloInstruction* operand = lookup_instruction(trace_request.operand()); - hlo_instruction = add_instruction( - HloInstruction::CreateTrace(trace_request.tag(), operand)); - break; - } - - case OpRequest::kSendRequest: { - const SendRequest& send_request = request.request().send_request(); - HloInstruction* operand = lookup_instruction(send_request.operand()); - HloInstruction* send = add_instruction(HloInstruction::CreateSend( - operand, send_request.channel_handle().handle())); - hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); - break; - } - - case OpRequest::kGatherRequest: { - const GatherRequest& gather_request = request.request().gather_request(); - HloInstruction* input_operand = - lookup_instruction(gather_request.input()); - HloInstruction* gather_indices_operand = - lookup_instruction(gather_request.gather_indices()); - std::vector window_bounds; - c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); - hlo_instruction = add_instruction(HloInstruction::CreateGather( - request.output_shape(), input_operand, gather_indices_operand, - gather_request.dimension_numbers(), window_bounds)); - break; - } - - case OpRequest::OP_NOT_SET: - LOG(FATAL) << "OperationRequest doesn't contain a request"; - - default: - LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); - } - (*instructions)[handle.handle()] = hlo_instruction; -} // NOLINT(readability/fn_size) - -} // namespace - -StatusOr> UserComputation::BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions) const { - tensorflow::mutex_lock lock(mutex_); - - VLOG(2) << "Building HloComputation from UserComputation " << name_ - << " at version " << version; - XLA_VLOG_LINES(3, session_computation_.DebugString()); - - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_computation, - ComputationLowerer::Lower( - tensorflow::strings::StrCat(name(), ".v", version), - session_computation_, version, std::move(hlo_resolver), debug_options, - include_unreachable_instructions)); - - return std::move(hlo_computation); -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h deleted file mode 100644 index 5544c868fe905c1ca7e6cab32738440add2e3b4f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation.h +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/session.pb.h" -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { - -// A UserComputation is the built-up computation that users create via the -// XLA Service interface. -// -// The XLA service adds instructions to a user computation via this -// interface. The state of the computation is stored as a SessionComputation -// proto which holds a record of all operation-building requests received by the -// XLA service. -// -// UserComputations are lowered to HloComputations which are passed to the high -// level compiler interface. -class UserComputation { - public: - // Factory used when restoring a computation from serialized session - // computation (computation snapshot) data. Remaps any references to - // computation handle via the old_to_new mapping. - // - // An error will occur if the old_to_new mapping cannot resolve a reference to - // a computation that is present in session_computation. - static StatusOr> MakeWithRemapping( - const SessionComputation& session_computation, - const ComputationHandle& handle, - const std::map& old_to_new); - - // Creates an empty computation with the given name and computation handle. - explicit UserComputation(const string& name, const ComputationHandle& handle); - - // Enqueues a parameter-retrieving instruction onto this user computation. - // Returns an error status if the parameter number is already registered with - // different values. - StatusOr AddParameterInstruction( - const ParameterRequest& parameter_request); - - // Enqueues a pad instruction onto this user computation. - StatusOr AddPadInstruction( - const PadRequest& pad_request); - - // Enqueues a tracing instruction onto this user computation. - // Returns an error status if the operand cannot be resolved. - Status AddTraceInstruction(const TraceRequest& trace_request); - - // Enqueues a random number generation instruction onto this user computation. - StatusOr AddRngInstruction( - const RngRequest& rng_request); - - // Enqueues a unary instruction onto this user computation. - // Returns an error status if the operand index is out of bounds. - StatusOr AddUnaryInstruction( - const UnaryOpRequest& unary_request); - - // Enqueues a batch norm training instruction onto this user computation. - StatusOr AddBatchNormTrainingInstruction( - const BatchNormTrainingRequest& batch_norm_training_request); - - // Enqueues a batch norm inference instruction onto this user computation. - StatusOr AddBatchNormInferenceInstruction( - const BatchNormInferenceRequest& batch_norm_inference_request); - - // Enqueues a batch norm grad instruction onto this user computation. - StatusOr AddBatchNormGradInstruction( - const BatchNormGradRequest& batch_norm_grad_request); - - // Enqueues a binary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddBinaryInstruction( - const BinaryOpRequest& binary_request); - - // Enqueues a ternary instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddTernaryInstruction( - const TernaryOpRequest& ternary_request); - - // Enqueues a variadic instruction onto this user computation. - // Returns an error status if the operand indices are out of bounds. - StatusOr AddVariadicInstruction( - const VariadicOpRequest& variadic_request); - - // Enqueues a constant instruction onto this user computation. - StatusOr AddConstantInstruction( - const ConstantRequest& constant_request); - - // Enqueues a get tuple element instruction onto this user computation. - StatusOr AddGetTupleElementInstruction( - const GetTupleElementRequest& get_tuple_element_request); - - // Enqueues a map instruction onto this user computation. - StatusOr AddMapInstruction( - const MapRequest& map_request, - const UserComputation& to_apply_computation); - - // Enqueues a reduce-precision instruction onto this user computation. - StatusOr AddReducePrecisionInstruction( - const ReducePrecisionRequest& reduce_precision_request); - - // Enqueues a convolution instruction onto this user computation. - StatusOr AddConvolveInstruction( - const ConvolveRequest& convolve_request); - - // Enqueues an FFT instruction onto this user computation. - StatusOr AddFftInstruction( - const FftRequest& fft_request); - - // Enqueues a cross replica sum instruction onto this user computation. - StatusOr AddCrossReplicaSumInstruction( - const CrossReplicaSumRequest& cross_replica_sum_request); - - // Enqueues an infeed instruction onto this user computation. - StatusOr AddInfeedInstruction( - const InfeedRequest& infeed_request); - - // Enqueues an outfeed instruction onto this user computation. - StatusOr AddOutfeedInstruction( - const OutfeedRequest& outfeed_request); - - // Enqueues a host compute instruction onto this user computation. - StatusOr AddHostComputeInstruction( - const HostComputeRequest& host_compute_request); - - // Enqueues a call instruction onto this user computation. - StatusOr AddCallInstruction( - const CallRequest& call_request, - const UserComputation& to_apply_computation); - - // Enqueues a custom call instruction onto this user computation. - StatusOr AddCustomCallInstruction( - const CustomCallRequest& custom_call_request); - - // Enqueues a dot instruction onto this user computation. - StatusOr AddDotInstruction( - const DotRequest& dot_request); - - // Enqueues a broadcast instruction onto this user computation. - StatusOr AddBroadcastInstruction( - const BroadcastRequest& broadcast_request); - - // Enqueues a reshape instruction onto this user computation. - StatusOr AddReshapeInstruction( - const ReshapeRequest& reshape_request); - - // Enqueues a transpose instruction onto this user computation. - StatusOr AddTransposeInstruction( - const TransposeRequest& transpose_request); - - // Enqueues a slice instruction onto this user computation. - StatusOr AddSliceInstruction( - const SliceRequest& slice_request); - - // Enqueues a dynamic slice instruction onto this user computation. - StatusOr AddDynamicSliceInstruction( - const DynamicSliceRequest& dynamic_slice_request); - - // Enqueues a dynamic update slice instruction onto this user computation. - StatusOr AddDynamicUpdateSliceInstruction( - const DynamicUpdateSliceRequest& dynamic_update_slice_request); - - // Enqueues a concatenate instruction onto this user computation. - StatusOr AddConcatenateInstruction( - const ConcatenateRequest& concatenate_request); - - // Enqueues a convert instruction onto this user computation. - StatusOr AddConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a bitcast element instruction onto this user computation. - StatusOr AddBitcastConvertInstruction( - const ConvertRequest& convert_request); - - // Enqueues a reduce instruction onto this user computation. - StatusOr AddReduceInstruction( - const ReduceRequest& reduce_request, - const UserComputation& to_apply_computation); - - // Enqueues a windowed reduce instruction onto this user computation. - StatusOr AddReduceWindowInstruction( - const ReduceWindowRequest& reduce_window_request, - const UserComputation& to_apply_computation); - - // Enqueues a select-and-scatter instruction onto this user - // computation. - StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& select_and_scatter_request, - const UserComputation& select_computation, - const UserComputation& scatter_computation); - - // Enqueues a reverse instruction onto this user computation. - StatusOr AddReverseInstruction( - const ReverseRequest& reverse_request); - - // Enqueues a while instruction onto this user computation. - StatusOr AddWhileInstruction( - const WhileRequest& while_request, - const UserComputation& condition_computation, - const UserComputation& body_computation); - - // Enqueues a conditional instruction on this user computation. - StatusOr AddConditionalInstruction( - const ConditionalRequest& conditional_request, - const UserComputation& true_computation, - const UserComputation& false_computation); - - // Enqueues a Send instruction onto this user computation. - StatusOr AddSendInstruction( - const SendRequest& send_request); - - // Enqueues a Recv instruction onto this user computation. - StatusOr AddRecvInstruction( - const RecvRequest& recv_request); - - // Enqueues a Gather instruction onto this user computation. - StatusOr AddGatherInstruction( - const GatherRequest& gather_request); - - // Returns the user-provided name of this user computation, which is provided - // via the XLA computation-building API. - const string& name() const { return name_; } - - // Subsequent executions of this computation will compute the value - // represented by handle, rather than the last expression enqueued - // on the computation. - Status SetReturnValue(const ComputationDataHandle& handle); - - // Return a versioned handle for this computation. - VersionedComputationHandle GetVersionedHandle() const; - - // Return a versioned handle for this computation with a version equal to the - // point at which given operation was added to the computation. - VersionedComputationHandle GetVersionedHandleAtOperation( - const ComputationDataHandle& operation) const; - - // Return a version value representing the current state of the - // computation. - VersionedComputationHandle::Version version() const; - - // Computes and returns the program shape for the user computation -- gathers - // parameters and result type into a single proto. A shared_ptr is used - // because the returned pointer refers to an internally cached value which may - // be discarded by the UserComputation object. This avoid unnecessary copies. - // - // If the parameter space is not dense (i.e. there are holes in the parameter - // numbers provided) then an error status is returned. - StatusOr> ComputeProgramShape( - VersionedComputationHandle::Version version) const; - - // Returns true if the given data handle does not depend on any parameter with - // index higher then num_parameters. That is, the value can be computed at - // compile time if we know the first num_parameters arguments. - StatusOr IsConstant(const ComputationDataHandle& handle, - int64 num_parameters); - - // Returns the output shape of the operation indicated by the given handle. - StatusOr GetShape(const ComputationDataHandle& handle); - - // Sets metadata on the Hlo instruction referenced by the given handle. - Status SetOpMetadata(const ComputationDataHandle& handle, - const OpMetadata& metadata); - - // Sets the device assignment on the Hlo instruction referenced by 'handle'. - Status SetOpSharding(const ComputationDataHandle& handle, - const OpSharding& sharding); - - // Builds a HLO computation from the UserComputation. The parameter "resolver" - // is a function which returns a pointer to the HloComputation corresponding - // to the given ComputationHandle at the given version. The resolver is used - // for operations, such as map, which call other computations and need a - // pointer to the called HloComputation to construct the respective HLO - // instructions. If include_unreachable_instructions is true, then - // instructions which are not reachable from the root are lowered into - // HloInstructions. - using HloComputationResolver = - std::function; - StatusOr> BuildHloComputation( - VersionedComputationHandle::Version version, - HloComputationResolver hlo_resolver, const DebugOptions& debug_options, - bool include_unreachable_instructions = true) const; - - // Return a vector containing the embedded computations used by this - // UserComputation. Only embedded computations which are called directly by - // this UserComputation are included. That is, the transitive closure of - // embedded computations is not included. - std::vector GetEmbeddedComputations( - VersionedComputationHandle::Version version) const; - - // Returns the number of OperationRequest objects in this UserComputation. - // The 'version' of a computation is identical to the number of - // OperationRequests in the UserComputation. - int64 request_count(VersionedComputationHandle::Version version) const { - return version; - } - - // Returns a copy of the internal session state for this computation -- this - // is useful for serializing the guts of a user computation, though references - // to other handles (e.g. referred-to computations) must be handled with care - // in the serialization / de-serialization process. - SessionComputation CloneSessionComputation( - VersionedComputationHandle::Version version) const; - - // Warning: typically we don't want to look up computation data handles until - // the computation is finished being built, for consistency purposes. We - // expose this routine for error reporting purposes so that we can provide - // more meaningful error messages from the XLA service layer. - // - // Returns the operation request that the handle comes from. - StatusOr LookUpRequestForErrorReporting( - const ComputationDataHandle& handle) const; - - // Retrieves the parameter metadata for the given parameter number. - // - // If the parameter number is invalid for this computation, nullopt is - // returned. When the return value has_value(), nullptr will never be - // the held value. - tensorflow::gtl::optional ParameterMetadata( - int parameter_number) const; - - private: - // Warning: dangerous mutating operation that doesn't respect versioning. - // This is only used at initialization time when constructing from a - // SessionComputation a la MakeWithRemapping. - // - // Remaps references to old computations (with handle values in the keys of - // old_to_new) to the computation handle given in the values. This is useful - // when loading computations from snapshots, to finish initialization, before - // the user computation is released into the wild. - Status RemapEmbeddedComputations( - const std::map& old_to_new) - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Returns the OperationRequest corresponding to the given handle. - StatusOr LookUpRequest( - const ComputationDataHandle& handle) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Creates a new ComputationDataHandle with the next available handle value. - ComputationDataHandle CreateComputationDataHandle() - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Checks whether the parameter numbers of the parameter operations are - // contiguous starting from zero. Returns appropriate error status if not. - Status CheckParametersAreContiguous( - VersionedComputationHandle::Version version) const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - VersionedComputationHandle GetVersionedHandleInternal() const - EXCLUSIVE_LOCKS_REQUIRED(mutex_); - - // Name of the computation. - string name_; - - mutable tensorflow::mutex mutex_; - - // State of the computation as a record of all operation-building requests. - SessionComputation session_computation_ GUARDED_BY(mutex_); - - // Mapping from parameter number to operation request containing the - // respective ParameterRequest. - std::map parameters_ GUARDED_BY(mutex_); - - // The next ComputationDataHandle value to assign. Handle values are assigned - // sequentially. - int64 next_handle_value_ GUARDED_BY(mutex_); - - // If handle_to_return_.has_handle() then an Execution of this Computation - // will compute the value represented by handle_to_return_, otherwise it will - // compute the value of (next_handle_value_ - 1). - ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); - - // Memoized ProgramShape and its version. A shared_ptr is used because - // references to this object are returned by ComputeProgramShape. - mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; - mutable std::shared_ptr program_shape_ GUARDED_BY(mutex_); - - TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); -}; - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc deleted file mode 100644 index 2fa163953f638c0038e9f6bb11ce2a3742e0558c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ /dev/null @@ -1,340 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/user_computation.h" - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" - -namespace op = xla::testing::opcode_matchers; - -namespace xla { -namespace { - -using UserComputationTest = ::testing::Test; - -TEST_F(UserComputationTest, SimpleComputation) { - const Shape kScalarShape = ShapeUtil::MakeShape(F32, {}); - const Shape kVectorShape = ShapeUtil::MakeShape(F32, {2}); - - // Build a simple three operation computatation: - // - // %constant = Constant({123, 42}) - // %param = Param(0) - // %outfeed = Outfeed(%constant) - // - // Build the computation at two different versions and check invariants. - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest constant_request; - *constant_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle constant_handle, - computation.AddConstantInstruction(constant_request)); - - ParameterRequest param_request; - *param_request.mutable_shape() = kScalarShape; - param_request.set_parameter(0); - param_request.set_name("param0"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle param_handle, - computation.AddParameterInstruction(param_request)); - OpMetadata metadata; - metadata.set_op_name("meta"); - TF_ASSERT_OK(computation.SetOpMetadata(param_handle, metadata)); - - OutfeedRequest outfeed_request; - *outfeed_request.mutable_operand() = constant_handle; - *outfeed_request.mutable_shape() = kVectorShape; - outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, - computation.AddOutfeedInstruction(outfeed_request)); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - { - // Test the computation at the latest version. In this case, the most - // recently added operation is an outfeed. However, the outfeed is not the - // root because outfeeds cannot be the root of a computation. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Program shape should have a single scalar parameter and scalar - // result. The outfeed instruction should not affect the program shape. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(latest_version.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - DebugOptions())); - // There should be one HloInstruction per UserComputation operation. - EXPECT_EQ(3, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - - { - // Test the computation at the version right after the parameter instruction - // is added. - VersionedComputationHandle version_at_param = - computation.GetVersionedHandleAtOperation(param_handle); - - // Program shape should have a single scalar parameter, and scalar result. - TF_ASSERT_OK_AND_ASSIGN( - std::shared_ptr program_shape, - computation.ComputeProgramShape(version_at_param.version)); - ASSERT_EQ(1, program_shape->parameters_size()); - EXPECT_TRUE( - ShapeUtil::Compatible(kScalarShape, program_shape->parameters(0))); - EXPECT_TRUE(ShapeUtil::Compatible(kScalarShape, program_shape->result())); - - // There should be two instructions, one for the constant and one for the - // parameter. The outfeed instruction should not be included. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(version_at_param.version, hlo_resolver, - DebugOptions())); - EXPECT_EQ(2, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - } - { - // Test the computation at the latest version, but lowered with - // include_unreachable_instructions set to false. - VersionedComputationHandle latest_version = - computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation( - latest_version.version, hlo_resolver, DebugOptions(), - /*include_unreachable_instructions=*/false)); - // There is only one reachable instruction, the parameter. - EXPECT_EQ(1, hlo_computation->instruction_count()); - // The root of the instruction should be the parameter instruction (not the - // outfeed). - EXPECT_THAT(hlo_computation->root_instruction(), op::Parameter()); - EXPECT_EQ(hlo_computation->root_instruction()->metadata().op_name(), - "meta"); - } -} - -TEST_F(UserComputationTest, EliminateScalarBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with scalar broadcast. - // - // %a = Constant({123, 42}) - // %b = Constant(1) - // %add = Add(%a, %b) - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ConstantRequest a_request; - *a_request.mutable_literal() = - Literal::CreateR1({123.0f, 42.0f})->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddConstantInstruction(a_request)); - - ConstantRequest b_request; - *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddConstantInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - // The binary operation has implicit scalar broadcast, should be converted - // to an explicit broadcast intruction and a binary instruction. - EXPECT_EQ(4, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - LOG(INFO) << hlo_computation->root_instruction()->ToString(); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast || - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with degenerate broadcast. - // - // %a = Param({1, 2, 3}); - // %b = Param({1, 2, 1}); - // %add = Add(%a, %b, {}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 2, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - const int64 kDevice = 7; - OpSharding sharding; - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_dimensions(1); - sharding.add_tile_assignment_devices(kDevice); - - TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // b a - // | | - // reshape | - // | | - // broadcast | - // \ / - // add - EXPECT_EQ(5, hlo_computation->instruction_count()); - ASSERT_THAT( - hlo_computation->root_instruction(), - op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); - - const HloInstruction* broadcast = - hlo_computation->root_instruction()->operand(1); - EXPECT_TRUE(broadcast->has_sharding()); - - const HloInstruction* reshape = broadcast->operand(0); - EXPECT_TRUE(reshape->has_sharding()); -} - -TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // Build a binary computation with in-dim broadcast and degenerate broadcast. - // - // %a = Param({2, 3}); - // %b = Param({2, 1, 4}); - // %add = Add(%a, %b, {0, 1}); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest add; - add.set_binop(BINOP_ADD); - *add.mutable_lhs() = a_handle; - *add.mutable_rhs() = b_handle; - add.add_broadcast_dimensions(0); - add.add_broadcast_dimensions(1); - TF_ASSERT_OK(computation.AddBinaryInstruction(add).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - // The binary operation has in-dim broadcast and degenerate broadcast, should - // first do the in-dim broadcast then convert the degnerate broadcast into a - // reshape and a broadcast. - // - // b a - // | | - // broadcast reshape - // | | - // | broadcast - // \ / - // add - EXPECT_EQ(6, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast && - operands[1]->opcode() == HloOpcode::kBroadcast); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h deleted file mode 100644 index 5732a56caffa31dde52dff5c2775f9fde0cacfbd..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// A data structure encapsulating a ComputationHandle and version value of that -// computation. This object is used to unambiguously refer to a particular -// computation in the service. -struct VersionedComputationHandle { - // A version value unambiguously specifying the state of the computation at a - // particular point in time as it is being built. This value is the - // ComputationDataHandle of the current root instruction. - using Version = int64; - - ComputationHandle handle; - Version version; - - string ToString() const; - bool operator==(const VersionedComputationHandle& other) const { - return (handle.handle() == other.handle.handle()) && - (version == other.version); - } - bool operator<(const VersionedComputationHandle& other) const { - return ((handle.handle() < other.handle.handle()) || - ((handle.handle() == other.handle.handle()) && - (version < other.version))); - } -}; - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc index 0d2288d8ea6ebb0ac4ac9468a211b161438fc5f1..393e75803888d8a642881c4d525b170d1e1180ba 100644 --- a/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace xla { @@ -55,7 +55,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -95,7 +95,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -136,7 +136,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); @@ -184,7 +184,7 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); TF_ASSERT_OK_AND_ASSIGN(bool changed, WhileLoopConstantSinking{}.Run(module.get())); diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index 321fdeb1ea313d2bc00b0210b422f36915f41453..09ddcffb22c2184262adf87d570870ec000c0e6f 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -98,14 +98,17 @@ static void CreateLoopInvariantCopy( // Returns true if `instruction` is worth hoisting only if it lets us hoist some // instruction using it. The rationale is that hoisting these instructions will // prevent simplification and fusion in the while body. -static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { +bool WhileLoopInvariantCodeMotion::NotWorthHoistingIndividually( + const HloInstruction& instruction) { switch (instruction.opcode()) { default: return false; + case HloOpcode::kConstant: + return !hoist_constants_; + case HloOpcode::kBitcast: case HloOpcode::kBroadcast: - case HloOpcode::kConstant: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: @@ -115,7 +118,8 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { } } -static StatusOr TryHoistingInvariantInstructionsFromWhileBody( +StatusOr +WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( HloInstruction* while_instr) { auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false); @@ -161,12 +165,16 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( } } - if (unhoisted_invariant_instructions.empty()) { + if (unhoisted_invariant_instructions.empty() && !hoist_constants_) { // There are no obviously loop invariant elements in the state being // threaded through the while loop so give up. In theory this precondition // is too strong -- we could have code that e.g. permutes the elements in // the while state but uses a select to pick the same value on every // iteration. + // + // If we were asked to hoist constants, we need to scan the while body for + // constants even if we didn't find any loop invariant values in the while + // state tuple. return false; } @@ -243,6 +251,9 @@ static StatusOr TryHoistingInvariantInstructionsFromWhileBody( } StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { + VLOG(2) << "HLO module before WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; std::vector while_instrs; for (auto* comp : module->computations()) { @@ -270,6 +281,14 @@ StatusOr WhileLoopInvariantCodeMotion::Run(HloModule* module) { TryHoistingInvariantInstructionsFromWhileBody(while_instr)); changed |= result; } + + if (changed) { + VLOG(2) << "HLO module after WhileLoopConstantSinking:"; + XLA_VLOG_LINES(2, module->ToString()); + } else { + VLOG(2) << "HLO module unchanged after WhileLoopConstantSinking"; + } + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h index 8c4b765b0003c48cfacb9d28e7c8259ac0927d66..8e6cc8787576e4f041229da5cf8dd2b09194eb2a 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h @@ -27,12 +27,28 @@ namespace xla { class WhileLoopInvariantCodeMotion : public HloPassInterface { public: + // If `hoist_constants` is true then constants are always hoisted out of while + // loop bodies. Otherwise they are only hoisted out if they enable other + // non-trivial computations to be hoisted out. + // + // Setting `hoist_constants` to false can be help if LICM is run in the mid + // level HLO pipeline because hoisting constants out of while loop bodies can + // break optimizations like constant folding. + explicit WhileLoopInvariantCodeMotion(bool hoist_constants = false) + : hoist_constants_(hoist_constants) {} ~WhileLoopInvariantCodeMotion() override = default; tensorflow::StringPiece name() const override { return "while-loop-invariant-code-motion"; } StatusOr Run(HloModule* module) override; + + private: + bool NotWorthHoistingIndividually(const HloInstruction& instruction); + StatusOr TryHoistingInvariantInstructionsFromWhileBody( + HloInstruction* while_instr); + + bool hoist_constants_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc index 799340fda905fb7d40b19b4cb79bb0fcb5629fd3..8831c513eee66e36163135b732f833d46cb7eb03 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -438,5 +439,77 @@ TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) { EXPECT_FALSE(simplified_loop); } +const char* const kConstantHoistingTestCase = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2]{0}) parameter(0) + p_body.1 = f32[2]{0} get-tuple-element(p_body), index=0 + const = f32[2]{0} constant({3, 4}) + add.0 = f32[2]{0} add(p_body.1, const) + ROOT root = (f32[2]{0}) tuple(add.0) +} + +condition { + p_cond = (f32[2]{0}) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2]{0} constant({1, 2}) + while_init = (f32[2]{0}) tuple(const_0) + ROOT while = (f32[2]{0}) while(while_init), condition=condition, body=body +} +)"; + +TEST_F(WhileLoopInvariantCodeMotionTest, HoistsConstantWhenAsked) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN( + bool simplified_loop, + WhileLoopInvariantCodeMotion{/*hoist_constants=*/true}.Run(&module())); + EXPECT_TRUE(simplified_loop); + + HloComputation* while_body = module().GetComputationWithName("wide.body"); + ASSERT_NE(while_body, nullptr); + + // We expect the while body to be the equivalent of: + // + // wide.body { + // wide_param.1 = (f32[2]{0}, f32[2]{0}) parameter(0) + // get-tuple-element.1 = f32[2]{0} get-tuple-element(wide_param.1), index=0 + // tuple.1 = (f32[2]{0}) tuple(get-tuple-element.1) + // get-tuple-element.4 = f32[2]{0} get-tuple-element(tuple.1), index=0 + // get-tuple-element.7 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // add.1 = f32[2]{0} add(get-tuple-element.4, get-tuple-element.7) + // tuple.3 = (f32[2]{0}) tuple(add.1) + // get-tuple-element.8 = f32[2]{0} get-tuple-element(tuple.3), index=0 + // get-tuple-element.9 = f32[2]{0} get-tuple-element(wide_param.1), index=1 + // ROOT tuple.4 = (f32[2]{0}, f32[2]{0}) tuple(get-tuple-element.8, + // get-tuple-element.9) + // } + + auto wide_param_1 = op::Parameter(0); + auto get_tuple_element_1 = op::GetTupleElement(wide_param_1, 0); + auto tuple_1 = op::Tuple(get_tuple_element_1); + auto get_tuple_element_4 = op::GetTupleElement(tuple_1, 0); + auto get_tuple_element_7 = op::GetTupleElement(wide_param_1, 1); + auto add_1 = op::Add(get_tuple_element_4, get_tuple_element_7); + auto tuple_3 = op::Tuple(add_1); + auto get_tuple_element_8 = op::GetTupleElement(tuple_3, 0); + auto get_tuple_element_9 = op::GetTupleElement(wide_param_1, 1); + auto tuple_4 = op::Tuple(get_tuple_element_8, get_tuple_element_9); + + EXPECT_THAT(while_body->root_instruction(), tuple_4); +} + +TEST_F(WhileLoopInvariantCodeMotionTest, DoesNotHoistConstantByDefault) { + ParseAndVerifyModule(kConstantHoistingTestCase); + + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopInvariantCodeMotion{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc index ed20b36292a7f24385603627d74fc72ba6b3b724..473eab2ea84eb8faf745cbe299bc80bcc1b62a35 100644 --- a/tensorflow/compiler/xla/service/while_util.cc +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -117,9 +117,13 @@ WhileUtil::MakeInstructionsLiveIn( HloInstruction* new_while = containing_computation->AddInstruction( HloInstruction::CreateWhile(new_while_shape, new_while_condition, new_while_body, new_while_init)); - TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( - while_instr, TupleUtil::ExtractPrefix( - new_while, while_instr->shape().tuple_shapes_size()))); + + // We want to get rid of the old while instruction even if it has side + // effecting operations so we do a manual HloComputation::RemoveInstruction + // instead of relying on HloComputation::ReplaceInstruction. + TF_RETURN_IF_ERROR(while_instr->ReplaceAllUsesWith(TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + TF_RETURN_IF_ERROR(containing_computation->RemoveInstruction(while_instr)); HloInstruction* while_body_param = new_while_body->parameter_instruction(0); std::vector live_in_instructions; diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h index 322d27b88cae60cb051f5fafdde70e2aafedbc1e..e67636d80f4b682fe1335eae535fb86105ac082b 100644 --- a/tensorflow/compiler/xla/service/while_util.h +++ b/tensorflow/compiler/xla/service/while_util.h @@ -38,17 +38,21 @@ class WhileUtil { }; // Replaces `while_instr` with a new while instruction that is equivalent to - // `while_instr`, except that it has all of the HLO instructions in + // `while_instr` except that it has all of the HLO instructions in // `instructions` as live-in, loop invariant values. These new live in values // are represented as new elements appended to the parameter of the while // loop, which must be of tuple shape. GetTupleElement instructions computing // each new live in value is returned in the `while_body_live_in_values` // vector. // - // Precondition: `while_instr` must have a tuple shaped state. + // Deletes `while_instr` after replacing it. // - // Every instruction in `instructions` must be contained in the computation - // that contains `while_instr`. + // Preconditions: + // + // `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. static StatusOr MakeInstructionsLiveIn( HloInstruction* while_instr, tensorflow::gtl::ArraySlice instructions); diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc index 974bc542a34d0af6d41ed29f36df87f4c164a360..d79d3297213e832306ea4726483b0f215df0f5d3 100644 --- a/tensorflow/compiler/xla/service/while_util_test.cc +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -16,8 +16,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/while_util.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/util.h" namespace xla { namespace { @@ -49,7 +50,7 @@ ENTRY entry { )"; TF_ASSIGN_OR_RETURN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); *entry_computation = module->entry_computation(); *param0 = (*entry_computation)->parameter_instruction(0); @@ -150,7 +151,7 @@ ENTRY main { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_string)); + ParseHloString(hlo_string)); HloComputation* while_body = module->GetComputationWithName("body"); @@ -163,5 +164,47 @@ ENTRY main { ASSERT_EQ(gte_list.size(), 1); EXPECT_EQ((*gte_list.begin())->name(), "gte.0"); } + +TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) { + const char* const hlo_string = R"( +HloModule WhileWithSideEffects + +body { + param.b = (s32[], s32[]) parameter(0) + gte.0 = s32[] get-tuple-element(param.b), index=0 + gte.1 = s32[] get-tuple-element(param.b), index=1 + add = s32[] add(gte.0, gte.1) + ROOT tuple = (s32[], s32[]) tuple(gte.0, add) +} + +cond { + param.c = (s32[], s32[]) parameter(0) + ROOT condition = pred[] infeed() +} + +ENTRY main { + init = (s32[], s32[]) parameter(0) + to_make_live_in = f32[100] parameter(1) + ROOT while = (s32[], s32[]) while(init), condition=cond, body=body +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseHloString(hlo_string)); + + HloComputation* main = module->GetComputationWithName("main"); + HloInstruction* while_instr = main->root_instruction(); + HloInstruction* to_make_live_in = main->parameter_instruction(1); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{to_make_live_in})); + + auto is_while = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kWhile; + }; + EXPECT_EQ(c_count_if(main->instructions(), is_while), 1); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h index 141347a792c23a2c542d7b564ab76c118409865d..14c35e7b84f07bebac33a9753ac26a8ee1418f1e 100644 --- a/tensorflow/compiler/xla/service_interface.h +++ b/tensorflow/compiler/xla/service_interface.h @@ -47,41 +47,22 @@ class ServiceInterface { virtual Status ResetDevice(const ResetDeviceRequest* arg, ResetDeviceResponse* result) = 0; - virtual Status LoadComputationSnapshot( - const LoadComputationSnapshotRequest* request, - LoadComputationSnapshotResponse* result) = 0; - - virtual Status Execute(const ExecuteRequest* arg, - ExecuteResponse* result) = 0; - virtual Status ExecuteGraph(const ExecuteGraphRequest* arg, ExecuteResponse* result) = 0; - virtual Status ExecuteParallel(const ExecuteParallelRequest* arg, - ExecuteParallelResponse* result) = 0; - virtual Status ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, ExecuteParallelResponse* result) = 0; - virtual Status ExecuteAsync(const ExecuteAsyncRequest* arg, - ExecuteAsyncResponse* result) = 0; - virtual Status WaitForExecution(const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0; virtual Status DeconstructTuple(const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0; - virtual Status GetComputationStats(const ComputationStatsRequest* arg, - ComputationStatsResponse* result) = 0; - virtual Status GetComputationGraphStats( const ComputationGraphStatsRequest* arg, ComputationStatsResponse* result) = 0; - virtual Status GetComputationShape(const GetComputationShapeRequest* arg, - GetComputationShapeResponse* result) = 0; - virtual Status GetShape(const GetShapeRequest* arg, GetShapeResponse* result) = 0; @@ -91,31 +72,9 @@ class ServiceInterface { virtual Status GetDeviceHandles(const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0; - // Methods used by ComputationBuilder. - virtual Status Computation(const ComputationRequest* arg, - ComputationResponse* result) = 0; - - virtual Status Op(const OpRequest* arg, OpResponse* result) = 0; - - virtual Status GetLocalShape(const GetLocalShapeRequest* arg, - GetLocalShapeResponse* result) = 0; - - virtual Status SetReturnValue(const SetReturnValueRequest* arg, - SetReturnValueResponse* results) = 0; - - virtual Status IsConstant(const IsConstantRequest* arg, - IsConstantResponse* result) = 0; - - virtual Status ComputeConstant(const ComputeConstantRequest* arg, - ComputeConstantResponse* result) = 0; - virtual Status ComputeConstantGraph(const ComputeConstantGraphRequest* arg, ComputeConstantResponse* result) = 0; - // Methods used by Computation. - virtual Status SnapshotComputation(const SnapshotComputationRequest* ag, - SnapshotComputationResponse* result) = 0; - // Methods used by GlobalData. virtual Status Unregister(const UnregisterRequest* arg, UnregisterResponse* result) = 0; diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 37c94ac543b166c14affd8165d244440ae6b67d6..18e54d23c241ae0d4c61d8be79ff021dfb02a3e6 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -47,6 +47,9 @@ struct ShapeTreeNode { // Children of this node, as indices into the container's nodes_ array. std::vector children; + // Tells whether this is a leaf node. + bool is_leaf = true; + explicit ShapeTreeNode(ShapeIndex index) : ShapeTreeNode(std::move(index), T()) {} ShapeTreeNode(ShapeIndex index, T data) @@ -122,9 +125,7 @@ class ShapeTree { // Returns true if the node at the given index is a leaf node (an array // shape). - bool IsLeaf(const ShapeIndex& index) const { - return Lookup(index)->children.empty(); - } + bool IsLeaf(const ShapeIndex& index) const { return Lookup(index)->is_leaf; } ShapeTree(const ShapeTree&) = default; ShapeTree& operator=(const ShapeTree&) = default; @@ -222,6 +223,9 @@ class ShapeTree { /*iterate_leaves_only=*/false); } + // Returns the number of leaf nodes in the tree. + int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -308,16 +312,14 @@ class ShapeTreeIterator : nodes_(nodes), node_(std::move(node)), iterate_leaves_only_(iterate_leaves_only) { - while (iterate_leaves_only && node_ != nodes_->end() && - !node_->children.empty()) { + while (iterate_leaves_only && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } } ShapeTreeIterator& operator++() { ++node_; - while (iterate_leaves_only_ && node_ != nodes_->end() && - !node_->children.empty()) { + while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) { ++node_; } return *this; @@ -330,8 +332,7 @@ class ShapeTreeIterator ShapeTreeIterator& operator--() { --node_; - while (iterate_leaves_only_ && node_ > nodes_->begin() && - !node_->children.empty()) { + while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) { --node_; } return *this; @@ -355,7 +356,7 @@ class ShapeTreeIterator ContainerType* nodes_; IteratorType node_; // True if we should not include interior nodes in our walk. - bool iterate_leaves_only_; + const bool iterate_leaves_only_; }; template @@ -376,6 +377,7 @@ void ShapeTree::InitChildren(const Shape& shape, const T& init_value, if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); node->children.reserve(size); + node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); for (int i = 0; i < size; ++i) { @@ -392,6 +394,7 @@ void ShapeTree::InitChildren(const Shape& shape, Node* node) { if (ShapeUtil::IsTuple(shape)) { const int64 size = ShapeUtil::TupleElementCount(shape); node->children.reserve(size); + node->is_leaf = false; ShapeIndex shape_index = node->data.first; shape_index.push_back(0); for (int i = 0; i < size; ++i) { diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index dc5facf1581c07fbb74dfcee95025692938632bd..51de82e95746281ed6e587b545dc933b48ce1ad4 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -116,6 +116,11 @@ TEST_F(ShapeTreeTest, InitValueConstructor) { TestInitValueConstructor(nested_tuple_shape_, 10); } +TEST_F(ShapeTreeTest, EmptyTupleMustHaveNoLeaves) { + ShapeTree shape_tree{ShapeUtil::MakeTupleShape({})}; + EXPECT_EQ(0, shape_tree.leaf_count()); +} + TEST_F(ShapeTreeTest, ArrayShape) { ShapeTree shape_tree{array_shape_}; *shape_tree.mutable_element({}) = 42; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7a897f6f8f99e65285e1be0757a55f703fc81c72..5db66599324913b9214d7623597060950246fb03 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/iterator_range.h" @@ -42,17 +41,35 @@ limitations under the License. namespace xla { +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + string ShapeIndex::ToString() const { - return tensorflow::strings::StrCat( - "{", tensorflow::str_util::Join(indices_, ","), "}"); + return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}"); } string ShapeIndexView::ToString() const { - return tensorflow::strings::StrCat( - "{", - tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_), - ","), - "}"); + return StrCat("{", + tensorflow::str_util::Join( + tensorflow::gtl::make_range(begin_, end_), ","), + "}"); +} + +bool ShapeIndexView::operator==(const ShapeIndexView& other) const { + if (size() != other.size()) { + return false; + } + for (auto it = begin(), other_it = other.begin(); it != end(); + ++it, ++other_it) { + if (*it != *other_it) { + return false; + } + } + return true; +} + +bool ShapeIndexView::operator!=(const ShapeIndexView& other) const { + return !(*this == other); } std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { @@ -67,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { namespace { +// Returns whether the given primitive type corresponds to an array shape. +bool IsArrayPrimitiveType(PrimitiveType primitive_type) { + return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE && + primitive_type != OPAQUE && primitive_type != TOKEN; +} + // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { - return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && - ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + if (!ShapeUtil::SameElementType(lhs, rhs)) { + VLOG(3) << "CompareShapes: lhs element type != rhs element type"; + return false; + } + + if (ShapeUtil::IsTuple(lhs)) { + return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); - } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { - return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); + } else if (!ShapeUtil::IsArray(lhs)) { + // Non-tuple, non-array tupes such as opaque and token types are trivially + // the same. + return true; } if (compare_layouts) { @@ -108,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions"; return false; } - if (!ShapeUtil::SameElementType(lhs, rhs)) { - VLOG(3) << "CompareShapes: lhs element type != rhs element type"; - return false; - } return true; } @@ -154,8 +179,8 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) - << "Tuples do not have a rank, shape: " << shape; + CHECK(ShapeUtil::IsArray(shape)) + << "Non-arrays do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -182,8 +207,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape result; PopulateShape(element_type, dimensions, &result); return result; @@ -206,8 +230,7 @@ StatusOr MakeShapeWithLayoutInternal( /* static */ Shape ShapeUtil::MakeShapeWithSparseLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, int64 max_sparse_elements) { - DCHECK_NE(TUPLE, element_type); - DCHECK_NE(OPAQUE, element_type); + CHECK(IsArrayPrimitiveType(element_type)); Shape shape = ShapeUtil::MakeShape(element_type, dimensions); *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); @@ -254,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return result; } +/* static */ Shape ShapeUtil::MakeTokenShape() { + Shape result; + result.set_element_type(TOKEN); + TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result)); + return result; +} + /* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape, Shape* tuple_shape) { TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape)); @@ -277,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) { - if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) { + if (!IsArray(shape)) { return false; } return primitive_util::BitWidth(shape.element_type()) == bits; @@ -303,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( case C64: case TUPLE: case OPAQUE: + case TOKEN: return false; default: @@ -318,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return primitive_util::IsFloatingPointType(shape.element_type()); } +/* static */ bool ShapeUtil::IsArray(const Shape& shape) { + return IsArrayPrimitiveType(shape.element_type()); +} + /* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) { return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), IsTuple); @@ -371,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); + CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -386,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return shape.element_type() == F32 && Rank(shape) == 0; } -/* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (IsTuple(shape)) { - string text = "("; - const char* prefix = ""; - for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape)); - prefix = ", "; - } - text += ")"; - return text; - } else { - return tensorflow::strings::StrCat( - tensorflow::str_util::Lowercase( - PrimitiveType_Name(shape.element_type())), - "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]"); - } -} namespace { @@ -453,48 +471,56 @@ StatusOr StringToPrimitiveType(const string& name) { } // namespace -/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { +/* static */ string ShapeUtil::HumanString(const Shape& shape) { if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { - tensorflow::strings::StrAppend(&text, prefix, - HumanStringWithLayout(elem_shape)); + StrAppend(&text, prefix, HumanString(elem_shape)); prefix = ", "; } text += ")"; return text; - } else { - string result = tensorflow::strings::StrCat( - LowercasePrimitiveTypeName(shape.element_type()), "["); - for (int i = 0; i < shape.dimensions().size(); i++) { - tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "", - shape.dimensions(i)); + } + return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[", + tensorflow::str_util::Join(shape.dimensions(), ","), "]"); +} + +/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { + if (IsTuple(shape)) { + string text = "("; + const char* prefix = ""; + for (const Shape& elem_shape : shape.tuple_shapes()) { + StrAppend(&text, prefix, HumanStringWithLayout(elem_shape)); + prefix = ", "; } - result += "]"; - if (!IsScalar(shape) && !IsOpaque(shape)) { - if (LayoutUtil::HasLayout(shape)) { - tensorflow::strings::StrAppend(&result, - LayoutUtil::HumanString(shape.layout())); - } + text += ")"; + return text; + } + string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "["); + for (int i = 0; i < shape.dimensions().size(); i++) { + StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i)); + } + result += "]"; + if (!IsScalar(shape) && IsArray(shape)) { + if (LayoutUtil::HasLayout(shape)) { + StrAppend(&result, LayoutUtil::HumanString(shape.layout())); } - return result; } + return result; } /* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) { std::vector parameters; for (auto& shape : program_shape.parameters()) { const int i = parameters.size(); - parameters.push_back( - tensorflow::strings::StrCat(i < program_shape.parameter_names_size() - ? program_shape.parameter_names(i) - : "(unknown)", - ": ", HumanString(shape))); + parameters.push_back(StrCat(i < program_shape.parameter_names_size() + ? program_shape.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); } - return tensorflow::strings::StrCat( - "(", tensorflow::str_util::Join(parameters, ", "), ") -> ", - HumanString(program_shape.result())); + return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ", + HumanString(program_shape.result())); } namespace { @@ -564,14 +590,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the primitive element type. TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, StringToPrimitiveType(element_type_string)); - if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || - primitive_type == OPAQUE) { + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (format_string.empty() && layout_string.empty()) { + if (primitive_type == OPAQUE) { + result = ShapeUtil::MakeOpaqueShape(); + } else if (primitive_type == TOKEN) { + result = ShapeUtil::MakeTokenShape(); + } else if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); } else if (format_string == "sparse") { @@ -616,43 +645,44 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { } /* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameDimensions(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringElementType); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs); } /* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs) { - if (lhs.element_type() == TUPLE) { + if (IsArray(lhs)) { + return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) && + CompatibleIgnoringElementType(lhs, rhs); + } else if (lhs.element_type() == TUPLE) { return rhs.element_type() == TUPLE && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), CompatibleIgnoringFpPrecision); + } else { + // Opaque, token, etc types are vacuously compatible. + return true; } - if (lhs.element_type() == OPAQUE) { - return rhs.element_type() == OPAQUE; - } - if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { - return CompatibleIgnoringElementType(lhs, rhs); - } - return false; } /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, @@ -674,10 +704,6 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { switch (primitive_type) { case PRED: return sizeof(int8); - case TUPLE: - LOG(FATAL) << "tuples have no definitive size"; - case OPAQUE: - LOG(FATAL) << "opaque have no definitive size"; case S8: return sizeof(int8); case S16: @@ -704,6 +730,13 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(double); case C64: return sizeof(complex64); + case TOKEN: + // Tokens require no space. + return 0; + case TUPLE: + case OPAQUE: + LOG(FATAL) << PrimitiveType_Name(primitive_type) + << " primitive type has no definitive size"; default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } @@ -712,28 +745,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); + } else if (IsArray(shape)) { + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; + } else if (shape.element_type() == TOKEN) { + return 0; } - int64 byte_size = ByteSizeOfElements(shape); - if (LayoutUtil::IsSparseArray(shape)) { - byte_size += ByteSizeOfSparseIndices(shape); - } - return byte_size; + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " primitive type has no definitive size"; } /* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, int64 pointer_size) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_EQ(TUPLE, shape.element_type()); CHECK_GT(pointer_size, 0); return pointer_size * shape.tuple_shapes_size(); } /* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(ShapeUtil::IsArray(shape)); + CHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; if (LayoutUtil::IsSparseArray(shape)) { @@ -758,13 +795,17 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { TF_DCHECK_OK(ValidateShape(shape)); - DCHECK(LayoutUtil::IsSparseArray(shape)); + CHECK(LayoutUtil::IsSparseArray(shape)); return LayoutUtil::MaxSparseElements(shape.layout()) * ShapeUtil::Rank(shape) * sizeof(int64); } /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { + if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { + return InvalidArgument("shape has invalid element type: %s", + shape.ShortDebugString().c_str()); + } if (shape.element_type() == TUPLE) { if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); @@ -780,10 +821,24 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { if (shape.tuple_shapes_size() > 0) { return InvalidArgument("non-tuple shape has tuple_shapes field"); } - if (shape.element_type() == PRIMITIVE_TYPE_INVALID) { - return InvalidArgument("shape has invalid element type: %s", - shape.ShortDebugString().c_str()); + + // Tokens and opaques can should not have layout or dimensions. + if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) { + if (shape.dimensions_size() != 0) { + return InvalidArgument( + "shape has %s element type, but has dimensions field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + if (shape.has_layout()) { + return InvalidArgument( + "shape has %s element type, but has layout field: %s", + LowercasePrimitiveTypeName(shape.element_type()).c_str(), + shape.ShortDebugString().c_str()); + } + return Status::OK(); } + if (Rank(shape) != shape.dimensions_size()) { return InvalidArgument( "shape's rank is mismatched with dimension count; rank=%lld " @@ -863,64 +918,25 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) { return !IsTuple(GetSubshape(shape, index)); } -/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) { - std::vector dimension_sizes; - std::vector degenerate_dimensions; - for (int64 i = 0; i < shape.dimensions_size(); ++i) { - if (shape.dimensions(i) == 1) { - degenerate_dimensions.push_back(i); - } else { - dimension_sizes.push_back(shape.dimensions(i)); +/* static */ int64 ShapeUtil::GetLeafCount(const Shape& shape) { + int64 count = 0; + ForEachSubshape(shape, [&](const Shape&, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + ++count; } - } - - // Construct minor_to_major of stripped shape. The order of the non-degenerate - // dimensions should be preserved from the original shape. First, create - // vector of the non-degenerate dimensions from the original minor_to_major - // array. - std::vector minor_to_major; - for (int64 i : shape.layout().minor_to_major()) { - if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(), - i) == degenerate_dimensions.end()) { - minor_to_major.push_back(i); - } - } + }); + return count; +} - // The dimensions in minor_to_major need to be renumbered to account for the - // degenerate dimensions which have removed. Decrement each dimension number - // once for each degenerate dimension which has a smaller number. - for (int i = 0; i < minor_to_major.size(); ++i) { - int adjustment = 0; - for (int64 dim : degenerate_dimensions) { - if (minor_to_major[i] > dim) { - adjustment++; - } +/* static */ std::vector ShapeUtil::GetLeafShapes( + const Shape& shape) { + std::vector leaves; + ForEachSubshape(shape, [&](const Shape& sub_shape, const ShapeIndex& index) { + if (IsLeafIndex(shape, index)) { + leaves.emplace_back(index, sub_shape); } - minor_to_major[i] -= adjustment; - } - - { - std::vector dims(minor_to_major.size()); - std::iota(dims.begin(), dims.end(), 0); - DCHECK(minor_to_major.size() == dims.size() && - std::is_permutation(minor_to_major.begin(), minor_to_major.end(), - dims.begin())); - } - Shape stripped_shape; - if (LayoutUtil::IsDenseArray(shape)) { - stripped_shape = MakeShapeWithLayout(shape.element_type(), dimension_sizes, - minor_to_major); - } else if (LayoutUtil::IsSparseArray(shape)) { - stripped_shape = - MakeShapeWithSparseLayout(shape.element_type(), dimension_sizes, - shape.layout().max_sparse_elements()); - } else { - stripped_shape = MakeShape(shape.element_type(), dimension_sizes); - } - - VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape); - VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape); - return stripped_shape; + }); + return leaves; } namespace { @@ -1028,6 +1044,9 @@ Status ForEachMutableSubshapeHelper( /* static */ std::tuple, std::vector> ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, const Shape& shape_post) { + CHECK(IsArray(shape_pre)); + CHECK(IsArray(shape_post)); + auto nil = std::make_tuple(false, std::vector(), std::vector()); std::vector deleted_indices; @@ -1085,6 +1104,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre, /* static */ std::vector> ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + // Unmodified dimensions are merely common factors of rank 1. auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()), AsInt64Slice(output_shape.dimensions())); @@ -1138,8 +1160,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - CHECK(LayoutUtil::HasLayout(input_shape) && - LayoutUtil::HasLayout(output_shape)); + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + CHECK(LayoutUtil::HasLayout(input_shape)); + CHECK(LayoutUtil::HasLayout(output_shape)); if (!SameElementType(input_shape, output_shape)) { return false; @@ -1301,6 +1325,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ tensorflow::gtl::optional ShapeUtil::AlignLayouts( const Shape& input_shape, const Shape& output_shape) { + CHECK(IsArray(input_shape)); + CHECK(IsArray(output_shape)); + int64 input_rank = Rank(input_shape); int64 output_rank = Rank(output_shape); @@ -1435,6 +1462,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete, Shape shape) { + CHECK(IsArray(shape)); shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); @@ -1456,6 +1484,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { + CHECK(IsArray(shape)); std::vector dims_to_delete; for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) { if (!p(i)) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 82c75f85d838f94cb040e56d59d0e012af5e0db0..ae2d17d6bbbfed96e1da192253838ae5e9a67e17 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -61,6 +62,8 @@ class ShapeIndex { public: ShapeIndex() = default; ShapeIndex(std::initializer_list init) : indices_(init) {} + template + ShapeIndex(InputIt start, InputIt end) : indices_(start, end) {} bool empty() const { return indices_.empty(); } size_t size() const { return indices_.size(); } @@ -131,6 +134,10 @@ class ShapeIndexView { ++new_begin; return ShapeIndexView(new_begin, end_); } + ShapeIndex ToShapeIndex() const { return ShapeIndex(begin_, end_); } + + bool operator==(const ShapeIndexView& other) const; + bool operator!=(const ShapeIndexView& other) const; string ToString() const; @@ -150,12 +157,22 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // properties, which do invariant checks before / after the operation. class ShapeUtil { public: + // Data structure which describes the coordinates and the shape, of a tuple + // shaped sub-shape. + struct IndexedShape { + IndexedShape() = default; + IndexedShape(ShapeIndex index, Shape shape) + : index(std::move(index)), shape(std::move(shape)) {} + ShapeIndex index; + Shape shape; + }; + // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes // may not actually be able to store this number of elements. See // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of // elements that can be stored in a sparse shape. - // Precondition: !IsTuple(shape) + // Precondition: IsArray(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -166,13 +183,11 @@ class ShapeUtil { // shapes. This includes only the size of the top-level buffer. For example, a // tuple is stored as an array of pointers to other buffers. In this case, // this method only returns the size of the pointer array. - // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) && - // !ShapeUtil::IsOpaque(shape) static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1); // Returns the number of bytes used to store the primitive_type. // - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) + // Precondition: ShapeUtil::IsArray(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); // Returns the number of bytes required to store the tuple member pointers for @@ -231,7 +246,7 @@ class ShapeUtil { } // Returns the higher-precision element type if a and b are both floating - // point types; otherwise, checks that they have the same element type + // point types; otherwise, checks that that they have the same element type // and returns it. static PrimitiveType HigherPrecisionElementType(const Shape& a, const Shape& b) { @@ -279,10 +294,10 @@ class ShapeUtil { // Scalar-specific static bool IsScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0; + return IsArray(shape) && Rank(shape) == 0; } static bool IsEffectiveScalar(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0; + return IsArray(shape) && TrueRank(shape) == 0; } static bool IsScalarF32(const Shape& shape); @@ -311,6 +326,10 @@ class ShapeUtil { // into a custom operation. static Shape MakeOpaqueShape(); + // Creates a token shape. Values of this shape are used for ordering + // side-effecting operations. + static Shape MakeTokenShape(); + // Appends a shape to the given tuple. static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape); @@ -410,11 +429,15 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } + // Returns whether the shape is an token value used for ordering + // side-effecting operations. + static bool IsToken(const Shape& shape) { + return shape.element_type() == TOKEN; + } + // Returns whether the shape is an array. Note that scalars are considered // arrays. - static bool IsArray(const Shape& shape) { - return !IsTuple(shape) && !IsOpaque(shape); - } + static bool IsArray(const Shape& shape); // Returns whether the shape is a tuple with at least one element which is // also a tuple. @@ -461,6 +484,13 @@ class ShapeUtil { // shape. static bool IsLeafIndex(const Shape& shape, const ShapeIndex& index); + // Returns the number of leaves in the shape. + static int64 GetLeafCount(const Shape& shape); + + // Retrieves all the leaf shapes and their indexes, in the order walked by + // the ForEachSubshape() API. + static std::vector GetLeafShapes(const Shape& shape); + // Calls the given visitor function for each subshape of the given shape. // Subshapes are visited in DFS pre-order starting with the entire shape // (index {}). @@ -483,26 +513,6 @@ class ShapeUtil { static Status ForEachMutableSubshapeWithStatus( Shape* shape, const MutatingStatusVisitorFunction& func); - // Removes all degenerate dimensions (size one) from the given shape. The - // stripped minor_to_major preserves the relative ordering of non-degenerate - // dimensions. The stripped shape has the property that the underlying - // representation (bits in memory) for the stripped shape is the same as the - // original shape modulo padding. Examples: - // - // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2} - // stripped shape: F32 [2], minor_to_major = {0} - // - // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1} - // stripped shape: F32 [6, 5], minor_to_major = {1, 0} - // - // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1} - // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1} - // - // input shape: F32 [1, 1], minor_to_major = {0, 1} - // stripped shape: F32 [], minor_to_major = {} - // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) - static Shape StripDegenerateDimensions(const Shape& shape); - // Permutes the dimensions by the given permutation, so // return_value.dimensions[permutation[i]] = argument.dimensions[i] static Shape PermuteDimensions(tensorflow::gtl::ArraySlice permutation, @@ -626,6 +636,28 @@ class ShapeUtil { .IgnoreError(); } + // These convenience wrappers don't take `base`, `count` and `incr` + // explicitly, but iterate over every element in `shape` instead. + + template + static Status ForEachIndexWithStatus(const Shape& shape, + const FnType& visitor_function) { + std::vector base(shape.dimensions_size()); + std::vector incr(shape.dimensions_size(), 1); + return ForEachIndexWithStatus(shape, base, + /*count=*/AsInt64Slice(shape.dimensions()), + incr, visitor_function); + } + + template + static void ForEachIndex(const Shape& shape, const FnType& visitor_function) { + ForEachIndexWithStatus(shape, + [&](tensorflow::gtl::ArraySlice indices) { + return StatusOr(visitor_function(indices)); + }) + .IgnoreError(); + } + // A parallel version of ForEachIndex(WithStatus). This can only be used if // the visitor_function is thread-safe and the order of iteration does not // matter. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index f7675e97da7b061bde063e5093256c2288f99c98..0ff514564bdb27b7afa4cf99b0d727f2c029a5ae 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { } TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { - string shape_string = "(f32[1],(f32[2]), f32[3])"; + string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])"; TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), - ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeShape(F32, {3}), }); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) @@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseOpaqueType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString("opaque[]")); + Shape expected = ShapeUtil::MakeOpaqueShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseTokenType) { + TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]")); + Shape expected = ShapeUtil::MakeTokenShape(); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + TEST(ShapeUtilTest, ParseInvalidShapeString) { string shape_strings[] = { "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", @@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); + + EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN)); + EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape())); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { @@ -449,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) { TEST(ShapeUtilTest, HumanString) { Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape token = ShapeUtil::MakeTokenShape(); Shape scalar = ShapeUtil::MakeShape(F32, {}); Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); - Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque)); + EXPECT_EQ("token[]", ShapeUtil::HumanString(token)); EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar)); EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix)); EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", ShapeUtil::HumanString(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(nested_tuple)); EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque)); @@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) { EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2)); EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})", ShapeUtil::HumanStringWithLayout(tuple)); - EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})", - ShapeUtil::HumanStringWithLayout(nested_tuple)); + EXPECT_EQ( + "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " + "token[])", + ShapeUtil::HumanStringWithLayout(nested_tuple)); ProgramShape prog = ShapeUtil::MakeProgramShape( {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); @@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) { "(unknown): u32[1,2], " "(unknown): s32[3,4], " "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); prog.add_parameter_names("arg0"); @@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) { "matrix: u32[1,2], " "matrix2: s32[3,4], " "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])", + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", ShapeUtil::HumanString(prog)); } @@ -713,16 +742,6 @@ TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) { ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1}))); } -TEST(ShapeUtilTest, StripDegenerateDimensions) { - EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShape(F32, {3, 1, 2})), - ShapeUtil::MakeShape(F32, {3, 2}))); - EXPECT_TRUE(ShapeUtil::Equal( - ShapeUtil::StripDegenerateDimensions( - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 1, 2}, 10)), - ShapeUtil::MakeShapeWithSparseLayout(F32, {3, 2}, 10))); -} - TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast( ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}), diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 4883380be1f8a291bda829dff713de549ba58c65..e7e0a19db0516e4210f6bb78d6b5e6968bf78b2a 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -117,11 +117,11 @@ cc_library( "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend "//tensorflow/compiler/xla/service:platform_util", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -138,8 +138,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_verifier", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -619,6 +619,7 @@ xla_test( xla_test( name = "exhaustive_f32_elementwise_op_test", + size = "enormous", srcs = ["exhaustive_f32_elementwise_op_test.cc"], backends = [ "cpu", @@ -626,7 +627,6 @@ xla_test( ], shard_count = 48, tags = [ - "enormous", "manual", "notap", ], @@ -697,8 +697,8 @@ xla_test( "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -776,30 +776,42 @@ xla_test( ], ) +CONVOLUTION_TEST_DEPS = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", +] + xla_test( name = "convolution_test", timeout = "long", srcs = ["convolution_test.cc"], shard_count = 25, - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array4d", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/client/xla_client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], + deps = CONVOLUTION_TEST_DEPS, +) + +xla_test( + name = "convolution_test_gpu_alternative_layout", + timeout = "long", + srcs = ["convolution_test.cc"], + backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, + backends = ["gpu"], + shard_count = 25, + deps = CONVOLUTION_TEST_DEPS, ) xla_test( @@ -1183,9 +1195,25 @@ xla_test( ], deps = [ ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "token_hlo_test", + srcs = ["token_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1508,11 +1536,11 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/client/xla_client:xla_computation", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 34c86e007beea1cbac04641bdbdab62dc567f13e..3a0f51fc66d65c8684bd607b9e8103559cd4d8d4 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -671,7 +671,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { @@ -684,7 +684,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { auto result_status = Execute(&b, {}); EXPECT_FALSE(result_status.ok()); EXPECT_THAT(result_status.status().error_message(), - HasSubstr("op BINOP_ADD with incompatible shapes")); + HasSubstr("op add with incompatible shapes")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 4ef0a77884c90b9fe32f96d3361fa3d80bde623b..722d882471a41a75c1e5e60f8c1a151b76c7e004 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -249,10 +249,10 @@ XLA_TEST_F(ConvertTest, ConvertR1F32ToR1S64) { -1.99f, -2.0f, -2.01f, - 0x1.FFFFFEp+62F, - 0x1.FFFFFCp+62F, - -0x1.FFFFFEp+62F, - -0x1.FFFFFCp+62F}; + 9223371487098961920.f, + 9223370937343148032.f, + -9223371487098961920.f, + -9223370937343148032.f}; std::unique_ptr arg_literal = Literal::CreateR1({arg}); auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); std::unique_ptr arg_data = diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 947959beb144e1509a77ad2f94b8493de46ba6f2..346bb3a3996ee5bf662b0f74dd0c2096efbf5295 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -47,9 +47,9 @@ class ConvolutionTest : public ClientLibraryTestBase { #if XLA_TEST_BACKEND_GPU // XLA:GPU sometimes uses FFT convolution which isn't as precise as spatial // convolution. So relax the absolute error threshold. - ErrorSpec error_spec_ = ErrorSpec(1e-2); + ErrorSpec error_spec_ = ErrorSpec(1e-2, 1e-4); #else - ErrorSpec error_spec_ = ErrorSpec(1e-4); + ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-4); #endif }; diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc index b15988776513a60c9e5c85d4780912106db98e75..b151187c4b8f01c5b46ccadf27d2e22a7c902e98 100644 --- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc +++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" namespace xla { namespace { @@ -32,11 +32,19 @@ class TrivialCrossReplicaSumTest : public HloTestBase {}; XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p = f32[3] parameter(0) - ROOT crs = f32[3] cross-replica-sum(p) + ROOT crs = f32[3] cross-replica-sum(p), to_apply=add })"; - auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal = Literal::CreateR1({1, 2, 3}); EXPECT_EQ(*literal, *ExecuteAndTransfer(std::move(module), {literal.get()})); } @@ -44,12 +52,20 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) { XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] parameter(1) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add })"; - auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = Literal::CreateR1({1, 2, 3}); auto literal1 = Literal::CreateR1({10, 20}); EXPECT_EQ( @@ -63,12 +79,20 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) { XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) { const char* module_str = R"( HloModule test + + add { + x = f32[] parameter(0) + y = f32[] parameter(1) + add = f32[] add(x, y) + } + ENTRY test_computation { p0 = f32[3] parameter(0) p1 = f32[2] constant({10, 20}) - ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1) + ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add })"; - auto module = tools::Parse(module_str, GetModuleConfigForTest()).ValueOrDie(); + auto module = + ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie(); auto literal0 = Literal::CreateR1({1, 2, 3}); auto literal1 = Literal::CreateR1({10, 20}); EXPECT_EQ(*Literal::MakeTuple({literal0.get(), literal1.get()}), diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 4854c649c15f2ab89bd3b343abd248be6e227c60..143ffbdeb409d91ab6d46d386aa5ff98ebc4ae10 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -14,12 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" // NB! TODO(b/74360564): These tests do not test out of bounds behavior since // that hasn't been specced yet. @@ -41,7 +41,7 @@ class GatherOperationTest : public HloTestBase { HloModuleConfig config; config.set_debug_options(GetDebugOptionsForTest()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - tools::Parse(hlo_text, config)); + ParseHloString(hlo_text, config)); EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt)); } }; diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 36e19e6507fa3b6f4a21949583f92716d2f44333..242cc5db11ff2bdf69209df7537216573d8afbf3 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -94,8 +94,7 @@ HloTestBase::HloTestBase(se::Platform* test_platform, /* static */ std::unique_ptr HloTestBase::CreateNewModule(const string& name) { - return MakeUnique(name, VersionedComputationHandle(), - GetModuleConfigForTest()); + return MakeUnique(name, GetModuleConfigForTest()); } /*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index eb3a2ea76a667a2afa2562f01d28f34384b84a21..249da87f489324ed9d377cc46a15cef5a9e74192 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -66,6 +66,15 @@ namespace xla { // // For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { + public: + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + static std::unique_ptr CreateNewModule( + const string& name = TestName()); + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -80,14 +89,6 @@ class HloTestBase : public ::testing::Test { ~HloTestBase() override {} - // Creates a new HLO module for a test. The module created will have - // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. If you want a fresh HloModule object and - // then add HloComputations to it, it's recommended to use this method in your - // tests. - static std::unique_ptr CreateNewModule( - const string& name = TestName()); - // Populates debug options from command-line flags and adjusts the options for // testing. It is recommended to use this when you need to pass in // DebugOptions, e.g. when creating a module from a string or a file. diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index da4cf4ae0c31bc194cd2ec9b845df36afbde69b0..22c664d1426c598dbb695ff1b66ce009b0a19c00 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - VerifyModule(); + VerifyModule(module_.get()); + } + for (int i = 0; i < modules_.size(); ++i) { + VerifyModule(modules_.at(i).get()); } HloTestBase::TearDown(); } -void HloVerifiedTestBase::VerifyModule() { - HloVerifier verifier; - xla::StatusOr mutated = verifier.Run(module_.get()); +void HloVerifiedTestBase::VerifyModule(HloModule* module) { + HloVerifier verifier(/*allow_mixed_precision=*/true); + xla::StatusOr mutated = verifier.Run(module); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); } else { @@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() { HloModule& HloVerifiedTestBase::module() { if (!module_) { - module_ = CreateNewModule(); + module_ = HloTestBase::CreateNewModule(); } return *module_; } +HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) { + modules_.emplace_back(HloTestBase::CreateNewModule()); + return modules_.back().get(); +} + void HloVerifiedTestBase::ParseAndVerifyModule( tensorflow::StringPiece hlo_text) { CHECK(!module_) << "Called ParseModule when test already has a module."; - TF_ASSERT_OK_AND_ASSIGN(module_, tools::Parse(hlo_text)); - VerifyModule(); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text)); + VerifyModule(module_.get()); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index e5bb14a8839acbdef8fd2b79bb0f574c46ea3d40..5b59cc77f61b05092d3afb331e73932c9edc5840 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -52,11 +52,23 @@ class HloVerifiedTestBase : public HloTestBase { shape_verifier_ = std::move(shape_verifier); } + // Creates a new module for a test, and stores it in modules_ so it can be + // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent + // creation of unverified modules. + HloModule* CreateNewModule(const string& name = TestName()); + + // It is confusing to store modules created by module() and CreateNewModule() + // in different fields, but it allows us to migrate tests to + // HloVerifiedTestBase more easily, so it's a win because we can verify more + // modules. See b/80488902. private: - std::unique_ptr module_; // Lazily populated. Access via module(). + // Lazily populated. Access via module(). + std::unique_ptr module_; + // Populated by calls to CreateNewModule. + std::vector> modules_; std::unique_ptr shape_verifier_; bool tear_down_called_ = false; - void VerifyModule(); + static void VerifyModule(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 2f46ee0be216d7dabf1c476d3cfb7d528f8ab6a4..082bc34136e004795ce300c66591758f47c665fe 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -124,8 +124,7 @@ class LLVMCompilerTest : public ::testing::Test { static std::unique_ptr CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return MakeUnique(TestName(), config); } }; diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 7df45bebebdd3eb2e71f27d831a8e2ac9e3b5f7c..3975e9125703ee081d4e84fa8bd27fcbe483ac34 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -488,10 +488,9 @@ TEST_F(MapTest, MapOperantionWithBuildError) { StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); - EXPECT_THAT( - computation_status.status().ToString(), - ::testing::HasSubstr("error from: ErrorAdd: Binary op BINOP_ADD with " - "different element types: f32[] and u16[]")); + EXPECT_THAT(computation_status.status().ToString(), + ::testing::HasSubstr("error from: ErrorAdd: Binary op add with " + "different element types: f32[] and u16[]")); } // MapTest disables inline and algsimp. MapTestWithFullOpt runs all diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index ec7ca20bdf266cf8ed220809c0c24bee473359be..41f723edf1ff3518686231f31b61b64291b1f6bf 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -273,5 +273,246 @@ XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { *result, *Literal::CreateR1({0.0, 4.0, 9.0}))); } +const char* const kScalarOps = R"( + HloModule m + + Add { + lhsadd = f32[] parameter(0) + rhsadd = f32[] parameter(1) + ROOT add = f32[] add(lhsadd, rhsadd) + } + + Max { + lhsmax = f32[] parameter(0) + rhsmax = f32[] parameter(1) + ROOT max = f32[] maximum(lhsmax, rhsmax) + } +)"; + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned(Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajor)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR2({{25, 36}, {49, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalar)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(1.17549e-38) + r2 = f32[2]{0} reduce(mul, c1), dimensions={0,2}, to_apply=Max + r3 = f32[2]{0} reduce(mul, c0), dimensions={0,2}, to_apply=Add + ROOT tuple = (f32[2]{0}, f32[2]{0}, f32[2]{0}) tuple(r1, r2, r3) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2]{0}, f32[2]{0}) fusion(p), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned(Literal::CreateR1({14, 22}), + Literal::CreateR1({36, 64}), + Literal::CreateR1({66, 138})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMinorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) + tuple(p0, r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2,2]{2,1,0}, f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}), + Literal::CreateR2({{3, 7}, {11, 15}}), + Literal::CreateR2({{5, 16}, {36, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionMajorWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2,2]{1,0} reduce(p0, c0), dimensions={0}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + r2 = f32[2,2]{1,0} reduce(mul, c1), dimensions={0}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) + tuple(r1, mul, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2,2]{2,1,0}, f32[2,2]{1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR2({{6, 8}, {10, 12}}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR2({{25, 36}, {49, 64}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionScalarWithExtraOutput)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + c0 = f32[] constant(0) + r1 = f32[2]{0} reduce(p0, c0), dimensions={0,2}, to_apply=Add + mul = f32[2,2,2]{2,1,0} multiply(p0, p0) + c1 = f32[] constant(5) + mul2 = f32[2,2,2]{2,1,0} multiply(p0, c1) + ROOT tuple = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) + tuple(r1, mul, mul2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + ROOT fusion = (f32[2]{0}, f32[2,2,2]{2,1,0}, f32[2,2,2]{2,1,0}) fusion(p), + kind=kInput, calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, + *Literal::MakeTupleOwned( + Literal::CreateR1({14, 22}), + Literal::CreateR3({{{1, 4}, {9, 16}}, {{25, 36}, {49, 64}}}), + Literal::CreateR3( + {{{5, 10}, {15, 20}}, {{25, 30}, {35, 40}}})))); +} + +XLA_TEST_F(MultiOutputFusionTest, + DISABLED_ON_CPU(MultiOutputReduceFusionNonConstInit)) { + const string testcase = tensorflow::strings::StrCat(kScalarOps, R"( + fused_reduce { + p0 = f32[2,2,2]{2,1,0} parameter(0) + init1 = f32[] parameter(1) + init2 = f32[] parameter(2) + r1 = f32[2,2]{1,0} reduce(p0, init1), dimensions={2}, to_apply=Add + r2 = f32[2,2]{1,0} reduce(p0, init2), dimensions={2}, to_apply=Max + ROOT tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(r1, r2) + } + + ENTRY reduce { + p = f32[2,2,2]{2,1,0} parameter(0) + i = f32[] parameter(1) + j = f32[] parameter(2) + ROOT fusion = (f32[2,2]{1,0}, f32[2,2]{1,0}) fusion(p, i, j), kind=kInput, + calls=fused_reduce + })"); + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::CreateR3({{{0, 2}, {3, 4}}, {{5, 6}, {7, 8}}}); + auto init1 = Literal::CreateR0(5); + auto init2 = Literal::CreateR0(6); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + Execute(std::move(module), {param.get(), init1.get(), init2.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned( + Literal::CreateR2({{167, 172}, {176, 180}}), + Literal::CreateR2({{6, 6}, {6, 8}})))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc index c0a2c0ca4cb8414e0771a541b9f963f9aedc8376..9052b188ed09a715b6ad7c3a40dc853d02cdd70c 100644 --- a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -15,9 +15,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_macros.h" -#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -73,7 +73,7 @@ ENTRY reduce.1 { } )"; - return tools::Parse(hlo_string); + return ParseHloString(hlo_string); } // TODO(b/72454718): XLA:GPU does not support executing code compiled without diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 52195db2aa74710b901dd7744a670764a034e96b..5653bf11a7364bf9ed79bcb6b53f7db31f454803 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -197,9 +197,10 @@ class SliceR1Test : public ClientLibraryTestBase, // vector. tensorflow::gtl::InlinedVector input(spec.input_dim0); std::iota(input.begin(), input.end(), NativeT()); + auto literal = Literal::CreateR1(input); XlaBuilder builder(TestName()); - auto original = builder.ConstantR1(input); + auto original = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(original, {spec.slice_start}, {spec.slice_limit}, {spec.slice_stride}); @@ -210,7 +211,9 @@ class SliceR1Test : public ClientLibraryTestBase, expected.push_back(i); } - ComputeAndCompareR1(&builder, expected, {}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); + ComputeAndCompareR1(&builder, expected, {arg.get()}); } }; @@ -365,15 +368,18 @@ XLA_TEST_P(SliceR2Test, DoIt) { const R2Spec& spec = GetParam(); Array2D input(spec.input_dim0, spec.input_dim1); input.FillUnique(); + auto literal = Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(spec.layout)); XlaBuilder builder(TestName()); - auto a = builder.ConstantR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(spec.layout)); + auto a = builder.Parameter(0, literal->shape(), "p0"); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr arg, + client_->TransferToServer(*literal)); std::unique_ptr> expected = ReferenceUtil::Slice2D( input, spec.slice_starts, spec.slice_limits, spec.slice_strides); - ComputeAndCompareR2(&builder, *expected, {}); + ComputeAndCompareR2(&builder, *expected, {arg.get()}); } INSTANTIATE_TEST_CASE_P( @@ -453,7 +459,7 @@ class SliceR4Test : public ClientLibraryTestBase, void Run(const R4Spec& spec) { Array4D values(spec.input_dims[0], spec.input_dims[1], spec.input_dims[2], spec.input_dims[3]); - values.FillRandom(3.14f); + values.FillIota(3.14159); auto expected = ReferenceUtil::Slice4D( values, spec.slice_starts, spec.slice_limits, spec.slice_strides); XlaBuilder builder(TestName()); diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index de1865138802bc72e9a4b2db7a21343b0d327108..dd7c541733634213606b5a7983b59bb1f14bf75c 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -26,6 +26,7 @@ namespace { template void PopulateWithRandomFloatingPointDataImpl(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); // Create uniform numbers between 1 and 1.125 to avoid creating denormal @@ -59,12 +60,14 @@ void PopulateWithRandomFloatingPointDataImpl(Literal* literal, template void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); PopulateWithRandomFloatingPointDataImpl(literal, engine); } @@ -73,6 +76,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template <> void PopulateWithRandomFloatingPointData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), BF16); std::uniform_real_distribution generator(-0.9f, 1.0f); TF_CHECK_OK(literal->Populate( @@ -84,6 +88,7 @@ void PopulateWithRandomFloatingPointData(Literal* literal, template void PopulateWithRandomIntegralData(Literal* literal, std::minstd_rand0* engine) { + CHECK(engine != nullptr); CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); std::uniform_int_distribution generator( @@ -107,6 +112,9 @@ StatusOr> MakeFakeLiteralInternal( } return Literal::MakeTupleOwned(std::move(elements)); } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } auto literal = MakeUnique(shape); switch (shape.element_type()) { case BF16: @@ -201,11 +209,13 @@ std::unique_ptr MakeRandomNonwrappingSliceIndex( std::minstd_rand0* engine) { const int64 rank = ShapeUtil::Rank(input_shape); std::vector start_indices(rank); - for (int i = 0; i < rank; ++i) { - const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - - ShapeUtil::GetDimension(slice_shape, i); - std::uniform_int_distribution generator(0, upper_bound); - start_indices[i] = generator(*engine); + if (engine != nullptr) { + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(*engine); + } } return Literal::CreateR1(start_indices); } @@ -321,20 +331,21 @@ StatusOr> MakeConstrainedArgument( } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - std::minstd_rand0 engine; - return MakeFakeLiteralInternal(shape, &engine); +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random) { + auto engine = pseudo_random ? MakeUnique() : nullptr; + return MakeFakeLiteralInternal(shape, engine.get()); } StatusOr>> MakeFakeArguments( - HloModule* const module) { + HloModule* const module, bool pseudo_random) { TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); - std::minstd_rand0 engine; + auto engine = pseudo_random ? MakeUnique() : nullptr; std::vector> arguments(params.size()); for (int i = 0; i < params.size(); ++i) { - TF_ASSIGN_OR_RETURN( - arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + TF_ASSIGN_OR_RETURN(arguments[i], MakeConstrainedArgument( + *dataflow, *params[i], engine.get())); } return std::move(arguments); } diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index f483cdebea5c7c8a43e73ab57748a93c97bb78d7..a8689f64981569ceb7c8a712f8ece00c99e8cf2d 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -55,16 +55,28 @@ class PseudorandomGenerator { }; // Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); +// status if the element type is currently unhandled for fake data +// generation. See below for documentation of pseudo_random. +StatusOr> MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // // Will handle special cases such as making sure that indices used for dynamic // slices are bounded, reduces that call adds use 0 as an init value, etc. +// +// If pseudo_random is true, the generated numbers will be generated +// deterministically in a pseudo random way unless the values are constrated to +// be e.g. init values as above. If pseudo_random is false, the returned values +// will be generated in a faster way that yields less interesting data, e.g. the +// values may all be just the same value. +// +// TODO(b/79942829): Make interesting argument generation fast enough that using +// pseudo_random does not save any noticeable amount of time so that the +// parameter can be removed. StatusOr>> MakeFakeArguments( - HloModule* const module); + HloModule* const module, bool pseudo_random = true); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4585244ce81c14ab6d4d629bb7d208d73c82248d --- /dev/null +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TokenHloTest : public HloTestBase {}; + +// TODO(b/79770375): Compile, not just verify the HLO module when the backends +// support kGenerateToken. +XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, TokenTree) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), + "param")); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Entry root is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { + std::unique_ptr module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "Operands of token instructions must be TOKEN types")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 415cf9c16a2613913265d0342e5ab9932de5eb19..e4a052c8f1c0009619c3a94606f6384d04006e4e 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -85,7 +85,9 @@ cc_library( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", + "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", @@ -134,7 +136,7 @@ tf_cc_binary( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/core:lib", ], ) @@ -163,7 +165,6 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", - "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc index fe03a6e7bdfe99877c250fe1ae22beee4c8018a2..14d01b5bfb067cc39abc4d6e0605007624b6e0ae 100644 --- a/tensorflow/compiler/xla/tools/convert_computation.cc +++ b/tensorflow/compiler/xla/tools/convert_computation.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/env.h" @@ -33,7 +33,7 @@ namespace xla { namespace tools { void RealMain(const string& mode, const string& path) { - SessionModule module; + HloSnapshot module; tensorflow::Env* env = tensorflow::Env::Default(); if (mode == "txt2bin") { TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module)); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index b815bbf854b82b323da7879c230a1026cae96625..5dd5150be339846d0775880931f615b92c5b08d8 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/service/computation_tracker.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/statusor.h" diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD deleted file mode 100644 index 0fa4b98d0a41a1e7c681bb2302da3b752315867b..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ /dev/null @@ -1,72 +0,0 @@ -# Build file for the Hlo parser. - -licenses(["notice"]) # Apache 2.0 - -package( - default_visibility = [":friends"], -) - -package_group( - name = "friends", - includes = [ - "//tensorflow/compiler/xla:friends", - ], -) - -# Filegroup used to collect source files for dependency checking. -filegroup( - name = "c_srcs", - data = glob([ - "**/*.cc", - "**/*.h", - ]), -) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -cc_library( - name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], - hdrs = [ - "hlo_lexer.h", - "hlo_token.h", - ], - deps = [ - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - ], -) - -cc_library( - name = "hlo_parser", - srcs = ["hlo_parser.cc"], - hdrs = ["hlo_parser.h"], - deps = [ - ":hlo_lexer", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - -tf_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_parser", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index df0501386c1e4de9111fbb6b2d9e8ec372dbf41e..f7574e0b1cc95daee6d6743ba4e2e490ee87e7c6 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -24,6 +24,9 @@ limitations under the License. // passing --use_fake_data on the command line. If the real data is available // in the proto and --use_fake_data is false, the real data is used. // +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// // The output format is: // // file_path: computation_name :: type:literal_str @@ -41,7 +44,9 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -63,105 +68,177 @@ namespace { // fields. struct Options { string fake_infeed_shape; + bool generate_fake_infeed = false; bool use_fake_data = false; bool print_result = true; int num_runs = 1; - bool xla_hlo_profile_last_run = false; }; // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // -// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; -// otherwise, no infeed is performed. -StatusOr> ReplayComputation(const HloSnapshot& module, - Client* client, - const Options& opts) { - TF_ASSIGN_OR_RETURN(auto computation, client->LoadSnapshot(module)); - - std::vector> arguments; +// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided. +// If generate_fake_infeed is true, the required infeed shape is derived from +// the computation and then used to provide a fake infeed shape. +// +// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided, +// no infeed is performed. +StatusOr ReplayComputation(const HloSnapshot& module, + LocalClient* client, const Options& opts) { + XlaComputation computation(module.hlo().hlo_module()); + + // Build the `argument_ptrs` vector, which contains ShapedBuffer*s to our + // arguments. This is a bit involved, because we may have to convert from + // GlobalData to ShapedBuffer*, and we have to manage the lifetime of all our + // objects. + std::vector scoped_shaped_buffer_arguments; + std::vector> global_data_arguments; + std::vector argument_ptrs; if (opts.use_fake_data) { - arguments = MakeFakeArgumentsOrDie(computation, client); + global_data_arguments = MakeFakeArgumentsOrDie(computation, client); + for (const auto& data : global_data_arguments) { + argument_ptrs.push_back( + client->GlobalDataToShapedBuffer(data->handle(), /*device_ordinal=*/0) + .ValueOrDie()); + } } else { // use recorded data if available for (const auto& proto : module.arguments()) { TF_ASSIGN_OR_RETURN(std::unique_ptr literal, Literal::CreateFromProto(proto)); - TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(*literal)); - arguments.push_back(std::move(data)); + TF_ASSIGN_OR_RETURN( + ScopedShapedBuffer data, + client->LiteralToShapedBuffer(*literal, /*device_ordinal=*/0)); + scoped_shaped_buffer_arguments.push_back(std::move(data)); + } + for (const auto& argument : scoped_shaped_buffer_arguments) { + argument_ptrs.push_back(&argument); } } + bool provide_infeed = false; + Shape infeed_shape; + if (!opts.fake_infeed_shape.empty()) { + StatusOr shape_status = + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); + TF_CHECK_OK(shape_status.status()); + infeed_shape = std::move(shape_status).ValueOrDie(); + provide_infeed = true; + } else if (opts.generate_fake_infeed) { + for (const auto& comp : computation.proto().computations()) { + for (const auto& instruction : comp.instructions()) { + if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) { + CHECK(!provide_infeed) + << "--generate_fake_infeed only works if the model has 0 or 1 " + "infeed ops, but this one has >= 2."; + provide_infeed = true; + infeed_shape = instruction.shape(); + LOG(INFO) << "Generating fake infeed shape for inferred shape: " + << ShapeUtil::HumanString(infeed_shape); + } + } + } + } // We only instantiate the thread pool if the user has requested that a - // concurrent infeed occur via the fake_infeed_shape. + // concurrent infeed occur via the fake_infeed_shape, or when + // --generate_fake_infeed is passed and there exists an infeed operation in + // the HloSnapshot. tensorflow::gtl::optional pool; - - if (!opts.fake_infeed_shape.empty()) { + std::unique_ptr data; + if (provide_infeed) { + data = std::move(MakeFakeLiteral(infeed_shape)).ValueOrDie(); + } + auto transfer_infeed = [&data, client]() { + TF_CHECK_OK(client->TransferToInfeed(*data)); + }; + if (provide_infeed) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([opts, client]() { - StatusOr shape_status = - ShapeUtil::ParseShapeString(opts.fake_infeed_shape); - TF_CHECK_OK(shape_status.status()); - Shape shape = std::move(shape_status).ValueOrDie(); - StatusOr> data_status = MakeFakeLiteral(shape); - TF_CHECK_OK(data_status.status()); - std::unique_ptr data = std::move(data_status).ValueOrDie(); - while (true) { - TF_CHECK_OK(client->TransferToInfeed(*data)); - } + pool->Schedule([transfer_infeed]() { + // There may be several infeed buffers needed, however we don't know how + // many. If we proactively transfer too many infeed buffers, we may run + // out of memory. If we transfer too few infeed buffers, the program will + // hang. Therefore, we register a callback that is called when the infeed + // becomes empty, and in this callback we will transfer another fake + // infeed. + auto infeed_manager = xla::gpu::GetOrCreateInfeedManager(); + infeed_manager->RegisterOnEmptyCallback(transfer_infeed); + transfer_infeed(); }); } - std::vector execute_arguments; - execute_arguments.reserve(arguments.size()); - for (auto& argument : arguments) { - execute_arguments.push_back(argument.get()); + std::vector argument_layouts; + for (const auto& param : computation.proto().program_shape().parameters()) { + argument_layouts.push_back(¶m); } + std::unique_ptr executable = + client->Compile(computation, argument_layouts, ExecutableBuildOptions()) + .ValueOrDie(); // Run the computation num_runs times, and return the result from the last // execution. - std::unique_ptr result; + StreamExecutorMemoryAllocator allocator( + client->platform(), + {client->platform()->ExecutorForDevice(0).ValueOrDie()}); + tensorflow::gtl::optional result; for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (opts.xla_hlo_profile_last_run && i == opts.num_runs - 1) { - execution_options.mutable_debug_options()->set_xla_hlo_profile(true); - } + ExecutableRunOptions run_options; + run_options.set_execution_profile(&profile); + run_options.set_allocator(&allocator); - if (opts.print_result) { - TF_ASSIGN_OR_RETURN( - result, client->ExecuteAndTransfer(computation, execute_arguments, - &execution_options, &profile)); - } else { - // If we're not printing the result, execute the computation but don't - // bother retrieving the result. This can be a significant speedup. - TF_RETURN_IF_ERROR(client - ->Execute(computation, execute_arguments, - &execution_options, &profile) - .status()); - } + TF_ASSIGN_OR_RETURN(result, executable->Run(argument_ptrs, run_options)); LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; } - return std::move(result); + // Check that --num_runs > 0, otherwise *result below will fail with an + // unhelpful error (because the loop didn't run any iterations). + CHECK_GT(opts.num_runs, 0) << "--num_runs must be > 0"; + TF_ASSIGN_OR_RETURN(std::unique_ptr result_literal, + client->ShapedBufferToLiteral(*result)); + return std::move(*result_literal); } -int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { - Client* client = ClientLibrary::LocalClientOrDie(); +StatusOr ParseInputFile(const string& filename, + const Options& opts) { tensorflow::Env* env = tensorflow::Env::Default(); + HloSnapshot snapshot; + if (tensorflow::ReadBinaryProto(env, filename, &snapshot).ok()) { + return snapshot; + } + CHECK(opts.use_fake_data) + << "Without --use_fake_data, you must pass an HloSnapshot -- HloProto " + "and textual HLO don't carry real data."; + fprintf(stderr, "%s: is not HloSnapshot. Trying HloProto.\n", + filename.c_str()); + + if (tensorflow::ReadBinaryProto(env, filename, snapshot.mutable_hlo()).ok()) { + return snapshot; + } + fprintf(stderr, "%s: is not HloProto. Trying HLO text.\n", filename.c_str()); + string contents; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(env, filename, &contents)); + StatusOr> module = ParseHloString(contents); + if (module.ok()) { + *snapshot.mutable_hlo()->mutable_hlo_module() = + module.ValueOrDie()->ToProto(); + return snapshot; + } + fprintf(stderr, "%s: is not HLO text. Nothing left to try.\n", + filename.c_str()); + return InvalidArgument("Could not parse %s.", filename.c_str()); +} + +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { - HloSnapshot snapshot; - auto status = tensorflow::ReadBinaryProto(env, arg, &snapshot); - if (!status.ok()) { - fprintf(stderr, "%s: is not HloSnapshot: %s.\n", arg, - status.ToString().c_str()); + StatusOr maybe_snapshot = ParseInputFile(arg, opts); + if (!maybe_snapshot.ok()) { continue; } - StatusOr> result_status = - ReplayComputation(snapshot, client, opts); + HloSnapshot snapshot = std::move(maybe_snapshot).ValueOrDie(); + StatusOr result_status = ReplayComputation(snapshot, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -169,12 +246,12 @@ int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { continue; } - std::unique_ptr result = result_status.ConsumeValueOrDie(); - if (result != nullptr) { + if (opts.print_result) { + Literal result = std::move(result_status).ValueOrDie(); fprintf(stdout, "%s: %s :: %s:%s\n", arg, snapshot.hlo().hlo_module().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); + ShapeUtil::HumanString(result.shape()).c_str(), + result.ToString().c_str()); if (snapshot.has_result()) { std::unique_ptr literal = Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie(); @@ -204,9 +281,9 @@ int main(int argc, char** argv) { "Number of times to run each computation"), tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), - tensorflow::Flag( - "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run, - "Pass --xla_hlo_profile the last time we run the computation."), + tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed, + "Whether a fake infeed shape should be generated " + "derived from the computation"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index be33bd6dd1304fa8fc6e5aed1d4c4d65bf97e692..b4f45cc972d3d397ddff8e8d9163d1fef387392f 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -218,6 +218,12 @@ Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); // Passed-varargs variant of the InvalidArgument factory above. Status InvalidArgumentV(const char* format, va_list args); +template +Status InvalidArgumentStrCat(Args&&... concat) { + return InvalidArgument( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + template Status UnimplementedStrCat(Args&&... concat) { return Unimplemented( @@ -486,6 +492,12 @@ bool c_is_sorted(const C& c) { return std::is_sorted(std::begin(c), std::end(c)); } +template +bool c_is_sorted(const C& c, Compare&& comp) { + return std::is_sorted(std::begin(c), std::end(c), + std::forward(comp)); +} + template auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { return std::adjacent_find(std::begin(c), std::end(c)); @@ -514,12 +526,29 @@ typename std::decay::type c_accumulate(const Sequence& sequence, T&& init, std::forward(binary_op)); } +template +typename std::iterator_traits< + decltype(std::begin(std::declval()))>::difference_type +c_count_if(const C& c, Pred&& pred) { + return std::count_if(std::begin(c), std::end(c), std::forward(pred)); +} + template int64 FindIndex(const C& c, Value&& value) { auto it = c_find(c, std::forward(value)); return std::distance(c.begin(), it); } +template +void InsertAt(C* c, int64 index, Value&& value) { + c->insert(c->begin() + index, std::forward(value)); +} + +template +void EraseAt(C* c, int64 index) { + c->erase(c->begin() + index); +} + // Returns true if `x` fits in 32-bits. template bool IsInt32(T x) { diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index f619b8dc24038af64a27fc0565c74447ca9d09cf..6f07e4606bef015214f2c564515c8258a906205b 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -17,7 +17,6 @@ syntax = "proto3"; import "tensorflow/compiler/xla/xla_data.proto"; import "tensorflow/compiler/xla/service/hlo.proto"; -import "tensorflow/compiler/xla/service/session.proto"; package xla; @@ -226,22 +225,6 @@ message ExecutionOptions { repeated DeviceHandle device_handles = 5; } -message SnapshotComputationRequest { - ComputationHandle computation = 1; -} - -message SnapshotComputationResponse { - SessionModule module = 1; -} - -message LoadComputationSnapshotRequest { - SessionModule module = 1; -} - -message LoadComputationSnapshotResponse { - ComputationHandle computation = 1; -} - message GetDeviceHandlesRequest { int64 device_count = 1; } @@ -300,11 +283,6 @@ message ResetDeviceRequest { message ResetDeviceResponse { } -message ComputationStatsRequest { - ComputationHandle computation = 1; - DebugOptions debug_options = 2; -} - message ComputationGraphStatsRequest { HloModuleProto computation = 1; DebugOptions debug_options = 2; @@ -314,14 +292,6 @@ message ComputationStatsResponse { ComputationStats stats = 1; } -message ComputationRequest { - string name = 1; -} - -message ComputationResponse { - ComputationHandle computation = 1; -} - message CreateChannelHandleRequest { } @@ -336,24 +306,6 @@ message UnregisterRequest { message UnregisterResponse { } -message SetReturnValueRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message SetReturnValueResponse { -} - -message ExecuteRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 5; -} - message ExecuteGraphRequest { HloModuleProto computation = 1; repeated GlobalDataHandle arguments = 2; @@ -362,10 +314,6 @@ message ExecuteGraphRequest { ExecutionOptions execution_options = 3; } -message ExecuteParallelRequest { - repeated ExecuteRequest requests = 1; -} - message ExecuteGraphParallelRequest { repeated ExecuteGraphRequest requests = 1; } @@ -379,21 +327,6 @@ message ExecuteParallelResponse { repeated ExecuteResponse responses = 1; } -message ExecuteAsyncRequest { - reserved 3, 4; - - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; - - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; -} - -message ExecuteAsyncResponse { - // A handle to the execution launched asynchronously. - ExecutionHandle execution = 1; -} - message WaitForExecutionRequest { ExecutionHandle execution = 1; } @@ -403,31 +336,13 @@ message WaitForExecutionResponse { ExecutionProfile profile = 2; } -message IsConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - int64 num_parameters = 3; -} - -message IsConstantResponse { - bool is_constant = 1; -} - -message ComputeConstantRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; - Layout output_layout = 3; - repeated LiteralProto parameters = 4; -} - message ComputeConstantGraphRequest { HloModuleProto computation = 1; Layout output_layout = 2; } message ComputeConstantResponse { - // A LiteralProto is returned directly for this request, instead of a - // ComputationDataHandle. + // A LiteralProto is returned directly for this request. LiteralProto literal = 1; } @@ -469,14 +384,6 @@ message LoadDataResponse { int64 nanoseconds = 5; } -message SpecializeRequest { - ComputationHandle computation = 1; - repeated GlobalDataHandle arguments = 2; -} - -message SpecializeResponse { -} - message GetShapeRequest { GlobalDataHandle data = 1; } @@ -485,14 +392,6 @@ message GetShapeResponse { Shape shape = 1; } -message GetComputationShapeRequest { - ComputationHandle computation = 1; -} - -message GetComputationShapeResponse { - ProgramShape program_shape = 1; -} - message UnpackRequest { GlobalDataHandle data = 1; } diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index b895ac045c361b2336e0081eadf16334d49d3bee..0af73e8a93060f4569ddef9697b89a6fa2b8674b 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -66,11 +66,16 @@ enum PrimitiveType { // in the dimensions field. TUPLE = 13; - // An opaque type used for passing context specific data to a custom - // operation. + // An opaque type used for passing context-specific data to a custom + // operation. Shapes of this primitive type will have empty dimensions and + // tuple_shapes fields. OPAQUE = 14; - // Next = 17 + // A token type threaded between side-effecting operations. Shapes of this + // primitive type will have empty dimensions and tuple_shapes fields. + TOKEN = 17; + + // Next = 18 } // Describes the value held inside padding elements. @@ -271,12 +276,6 @@ message ExecutionProfile { int64 compute_and_transfer_time_ns = 5; } -// Handle given to a user that represents a computation that the user builds up -// before execution. -message ComputationHandle { - int64 handle = 1; -} - // Handle given to a user that represents an execution that the user launched // asynchronously on the device. message ExecutionHandle { @@ -290,13 +289,6 @@ message GlobalDataHandle { int64 handle = 1; } -// Handle given to a user that represents a data result in a computation. -// This is used to pass to subsequent computations that depends upon the data as -// an operand. -message ComputationDataHandle { - int64 handle = 1; -} - // Handle given to a user that represents a replicated virtual device. Each // replicated device represents N physical devices for execution where N is the // number of replicas. @@ -436,44 +428,6 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } -// Operation requests that are all collected as a tagged union with a oneof -// field in OpRequest. - -message ConstantRequest { - LiteralProto literal = 2; -} - -message GetTupleElementRequest { - ComputationDataHandle operand = 2; - int64 index = 3; -} - -message SliceRequest { - ComputationDataHandle operand = 2; - repeated int64 start_indices = 3; - repeated int64 limit_indices = 4; - repeated int64 strides = 5; -} - -message DynamicSliceRequest { - // Operand from which to slice at dynamic 'start_indices'. - ComputationDataHandle operand = 2; - // Dynamically computed 'start_indices' for slice operation. - ComputationDataHandle start_indices = 3; - // Slice sizes for each dimension (note that indices calculations are computed - // modulo dimension sizes to avoid out-of-bound array accesses). - repeated int64 slice_sizes = 4; -} - -message DynamicUpdateSliceRequest { - // Operand on which slice 'update' is to be applied. - ComputationDataHandle operand = 2; - // The slice update to apply to 'operand'. - ComputationDataHandle update = 3; - // Dynamically computed start indices for the update slice operation. - ComputationDataHandle start_indices = 4; -} - message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; @@ -511,13 +465,6 @@ message ConvolutionDimensionNumbers { // Next = 13 }; -message ConvolveRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kernel. - ConvolutionDimensionNumbers dimension_numbers = 5; -} - enum FftType { FFT = 0; // Forward FFT; complex in, complex out. IFFT = 1; // Inverse FFT; complex in, complex out. @@ -526,56 +473,6 @@ enum FftType { // fft_length real out } -message FftRequest { - FftType fft_type = 1; - repeated int64 fft_length = 2; // Multivalent for higher-order FFT. - ComputationDataHandle operand = 3; -} - -message InfeedRequest { - // The shape of the data returned by reading the device's infeed buffer. - Shape shape = 2; - - // Additional infeed configuration for the backend. - bytes config = 3; -} - -message OutfeedRequest { - // The shape of the data returned by reading the device's outfeed buffer. - Shape shape = 1; - - // Operand to the Outfeed. Supports tuple. - ComputationDataHandle operand = 2; - - // Backend-specific information for how to perform the outfeed. - bytes outfeed_config = 3; -} - -message CallRequest { - ComputationHandle to_apply = 2; - repeated ComputationDataHandle operands = 3; -} - -message CustomCallRequest { - string call_target_name = 2; - repeated ComputationDataHandle operands = 3; - Shape shape = 4; -} - -message HostComputeRequest { - // Operand to the HostCompute. Supports tuple. - repeated ComputationDataHandle operands = 1; - - // Name used to identify HostSend/Recv channels. - string channel_name = 2; - - // Cost estimate in nanoseconds. - int64 cost_estimate_ns = 3; - - // The shape of any data returned by host. - Shape shape = 4; -} - message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -587,297 +484,6 @@ message DotDimensionNumbers { repeated int64 rhs_batch_dimensions = 4; }; -message DotRequest { - ComputationDataHandle lhs = 2; - ComputationDataHandle rhs = 3; - DotDimensionNumbers dimension_numbers = 4; -} - -message MapRequest { - repeated ComputationDataHandle operands = 2; - ComputationHandle to_apply = 3; - repeated ComputationDataHandle static_operands = 4; - // The dimensions over which to map. - // Example mapping a Dot operation along the batch dimension 0: - // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] - // Map({operand0, operand1}, Dot, {0}) - repeated int64 dimensions = 5; -} - -message ReduceRequest { - // Operand to the reduction. - ComputationDataHandle operand = 2; - - // Initial value for the reduction. This must be consistent with the result - // shape of to_apply. - ComputationDataHandle init_value = 3; - - // The dimensions to reduce over. - repeated int64 dimensions = 4; - - // The computation to apply in the reduction. - ComputationHandle to_apply = 5; -} - -message ReduceWindowRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle init_value = 3; - Window window = 4; - ComputationHandle to_apply = 5; -} - -message BatchNormTrainingRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - float epsilon = 4; - int64 feature_index = 5; -} - -message BatchNormInferenceRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle offset = 3; - ComputationDataHandle mean = 4; - ComputationDataHandle variance = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message BatchNormGradRequest { - ComputationDataHandle operand = 1; - ComputationDataHandle scale = 2; - ComputationDataHandle mean = 3; - ComputationDataHandle variance = 4; - ComputationDataHandle grad_output = 5; - float epsilon = 6; - int64 feature_index = 7; -} - -message CrossReplicaSumRequest { - ComputationDataHandle operand = 2; -} - -message SelectAndScatterRequest { - // Operand array on which the windows slide. - ComputationDataHandle operand = 2; - - // Source array for the data to scatter. - ComputationDataHandle source = 3; - - // Initial scalar value for each element in the output. - ComputationDataHandle init_value = 4; - - // Window configuration. - Window window = 5; - - // Binary function used to select an element from each window. - ComputationHandle select = 6; - - // Binary function used to combine each scattered value from source with the - // current output value at the selected location. - ComputationHandle scatter = 7; -} - -message ReverseRequest { - ComputationDataHandle operand = 2; - repeated int64 dimensions = 3; -} - -message BroadcastRequest { - ComputationDataHandle operand = 2; - repeated int64 broadcast_sizes = 3; -} - -message PadRequest { - ComputationDataHandle operand = 2; - ComputationDataHandle padding_value = 3; - PaddingConfig padding_config = 4; -} - -message ReshapeRequest { - ComputationDataHandle operand = 2; - - // The dimension order for collapse (from fastest-changing to slowest). - repeated int64 dimensions = 3; - - // The new dimension sizes (from dimension 0 to n-1). - repeated int64 new_sizes = 4; -} - -message TransposeRequest { - ComputationDataHandle operand = 2; - - // The permutation of the operand's dimensions (in the range 0 to n-1). - repeated int64 dimensions = 3; -} - -message ParameterRequest { - Shape shape = 2; - int64 parameter = 3; - string name = 4; -} - -message GetLocalShapeRequest { - ComputationHandle computation = 1; - ComputationDataHandle operand = 2; -} - -message GetLocalShapeResponse { - Shape shape = 1; -} - -message TraceRequest { - string tag = 2; - ComputationDataHandle operand = 3; -} - -message ConvertRequest { - ComputationDataHandle operand = 2; - PrimitiveType new_element_type = 3; -} - -message ConcatenateRequest { - repeated ComputationDataHandle operands = 2; - // The dimension in which we concatenate; e.g. if you had dimension arrays of - // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1]. - // Attempting to concatenate those in dimension 1 would produce an error, as - // 4 != 5 (and there is no ragged array support). - int64 dimension = 3; -} - -message ConditionalRequest { - ComputationDataHandle predicate = 2; - ComputationDataHandle true_operand = 3; - ComputationHandle true_computation = 4; - ComputationDataHandle false_operand = 5; - ComputationHandle false_computation = 6; -} - -message WhileRequest { - ComputationHandle condition = 2; - ComputationHandle body = 3; - ComputationDataHandle init = 4; -} - -enum UnaryOperation { - UNOP_INVALID = 0; - - // Elementwise, logical negation on booleans and bitwise negation on ints. - UNOP_NOT = 1; - - // Elementwise, computes e^x. - UNOP_EXP = 2; - - // Elementwise, computes -x. - UNOP_NEGATE = 3; - - // Puts the elements in the operand into sorted order. - UNOP_SORT = 4; - - // Elementwise, computes tanh(x). - UNOP_TANH = 5; - - // Elementwise, computes the natural logarithm of x. - UNOP_LOG = 6; - - // Elementwise, computes the floor of x. - UNOP_FLOOR = 7; - - // Elementwise, computes the ceil of x. - UNOP_CEIL = 8; - - // Elementwise, computes the abs of x. - UNOP_ABS = 9; - - // Elementwise, computes the sign of x. - UNOP_SIGN = 10; - - // Elementwise, tests if values are finite (not NaN or inf) - UNOP_IS_FINITE = 11; - - // Elementwise, computes the cosine of x. - UNOP_COS = 12; - - // Elementwise, computes the sine of x. - UNOP_SIN = 13; - - // Elementwise, rounds x to nearest integral value, rounding half-way cases - // away from zero. - UNOP_ROUND_NEAREST_AFZ = 14; - - // Elementwise, extract real component of complex x. - UNOP_REAL = 15; - - // Elementwise, extract real component of complex x. - UNOP_IMAG = 16; - - // Elementwise, computes clz(x). - UNOP_CLZ = 17; - - // Elementwise, computes exp(x)-1. - UNOP_EXPM1 = 18; - - // Elementwise, computes log(x+1). - UNOP_LOG1P = 19; -} - -message UnaryOpRequest { - UnaryOperation unop = 2; - ComputationDataHandle operand = 3; -} - -enum BinaryOperation { - BINOP_INVALID = 0; - - // Arithmetic operations. - BINOP_ADD = 1; - BINOP_DIV = 2; - BINOP_MUL = 3; - BINOP_SUB = 4; - - // Comparison operators. - BINOP_EQ = 5; - BINOP_GE = 6; - BINOP_GT = 7; - BINOP_LE = 8; - BINOP_LT = 9; - BINOP_NE = 10; - - // Element-wise maximum. - BINOP_MAX = 14; - - // Element-wise minimum. - BINOP_MIN = 15; - - // Raises the left-hand-side to the right-hand-side power. - BINOP_POW = 16; - - // Remainder operation. - BINOP_REM = 17; - - // Element-wise, logical operators on booleans and bitwise operators on ints. - BINOP_AND = 18; - BINOP_OR = 19; - - BINOP_SHIFT_LEFT = 20; - BINOP_SHIFT_RIGHT_ARITHMETIC = 21; - BINOP_SHIFT_RIGHT_LOGICAL = 22; - - // Complex from real, imag. - BINOP_COMPLEX = 23; - - // Computes the 4-quadrant arctangent of the y, x input arguments. - BINOP_ATAN2 = 24; -} - -message BinaryOpRequest { - BinaryOperation binop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - repeated int64 broadcast_dimensions = 5; -} - enum RandomDistribution { RNG_INVALID = 0; @@ -892,67 +498,6 @@ enum RandomDistribution { // Next: 4 } -message RngRequest { - RandomDistribution distribution = 2; - repeated ComputationDataHandle parameter = 3; - Shape shape = 4; -} - -enum TernaryOperation { - TRIOP_INVALID = 0; - - // Given a predicate and two operands, selects operand0 if the predicate is - // true and operand1 if the predicate is false. - TRIOP_SELECT = 1; - - // Given a min, max and an operand returns the operand if between min and max, - // else returns min if operand is less than min or max if operand is greater - // than max. - TRIOP_CLAMP = 3; -} - -message TernaryOpRequest { - TernaryOperation triop = 2; - ComputationDataHandle lhs = 3; - ComputationDataHandle rhs = 4; - ComputationDataHandle ehs = 5; -} - -enum VariadicOperation { - VAROP_INVALID = 0; - - // Creates a tuple from its operands. - VAROP_TUPLE = 1; -} - -message VariadicOpRequest { - VariadicOperation varop = 2; - repeated ComputationDataHandle operands = 3; -} - -message ReducePrecisionRequest { - ComputationDataHandle operand = 1; - int32 exponent_bits = 2; - int32 mantissa_bits = 3; -} - -message SendRequest { - ComputationDataHandle operand = 1; - ChannelHandle channel_handle = 2; -} - -message RecvRequest { - Shape shape = 1; - ChannelHandle channel_handle = 2; -} - -message GatherRequest { - ComputationDataHandle input = 1; - ComputationDataHandle gather_indices = 2; - GatherDimensionNumbers dimension_numbers = 3; - repeated int64 window_bounds = 4; -} - message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -983,59 +528,3 @@ message OpSharding { // to. repeated OpSharding tuple_shardings = 5; } - -message OpRequest { - ComputationHandle computation = 1; - OpMetadata metadata = 33; - OpSharding sharding = 40; - - oneof op { - BinaryOpRequest binary_op_request = 2; - BroadcastRequest broadcast_request = 3; - CallRequest call_request = 4; - ConcatenateRequest concatenate_request = 5; - ConstantRequest constant_request = 6; - ConvertRequest convert_request = 7; - ConvolveRequest convolve_request = 8; - CrossReplicaSumRequest cross_replica_sum_request = 9; - CustomCallRequest custom_call_request = 10; - DotRequest dot_request = 43; - DynamicSliceRequest dynamic_slice_request = 11; - DynamicUpdateSliceRequest dynamic_update_slice_request = 12; - GetTupleElementRequest get_tuple_element_request = 13; - InfeedRequest infeed_request = 14; - MapRequest map_request = 15; - PadRequest pad_request = 16; - ParameterRequest parameter_request = 17; - ReducePrecisionRequest reduce_precision_request = 36; - ReduceRequest reduce_request = 18; - ReduceWindowRequest reduce_window_request = 19; - ReshapeRequest reshape_request = 20; - ReverseRequest reverse_request = 21; - RngRequest rng_request = 22; - SelectAndScatterRequest select_and_scatter_request = 23; - SliceRequest slice_request = 24; - TernaryOpRequest ternary_op_request = 25; - TraceRequest trace_request = 26; - TransposeRequest transpose_request = 34; - UnaryOpRequest unary_op_request = 27; - VariadicOpRequest variadic_op_request = 28; - WhileRequest while_request = 29; - SendRequest send_request = 30; - RecvRequest recv_request = 31; - OutfeedRequest outfeed_request = 32; - BatchNormTrainingRequest batch_norm_training_request = 35; - BatchNormGradRequest batch_norm_grad_request = 37; - BatchNormInferenceRequest batch_norm_inference_request = 38; - FftRequest fft_request = 41; - ConvertRequest bitcast_convert_request = 42; - ConditionalRequest conditional_request = 44; - HostComputeRequest host_compute_request = 45; - GatherRequest gather_request = 46; - // Next: 47 - } -} - -message OpResponse { - ComputationDataHandle output = 1; -} diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 0f9c80404ad33c39ae783e0bfa3cfb26e342fe3d..50b1ae5cc3cba2d6ac89c4415a3419ffdf7aec93 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -31,13 +31,15 @@ py_library( "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", "//tensorflow/contrib/coder:coder_py", "//tensorflow/contrib/compiler:compiler_py", + "//tensorflow/contrib/autograph", "//tensorflow/contrib/constrained_optimization", + "//tensorflow/contrib/control_flow", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/data", - "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/deprecated:deprecated_py", + "//tensorflow/contrib/distribute:distribute", "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", @@ -83,7 +85,6 @@ py_library( "//tensorflow/contrib/proto", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", - "//tensorflow/contrib/autograph", "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/recurrent:recurrent_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9aad772f0acd941d50d6ba238d345616195a6939..ad8c40395c2cdcc5e4288e04bb2115bd3627cdc9 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -30,6 +30,7 @@ from tensorflow.contrib import cluster_resolver from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import constrained_optimization +from tensorflow.contrib import control_flow from tensorflow.contrib import copy_graph from tensorflow.contrib import crf from tensorflow.contrib import cudnn_rnn diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 62d1b1cf079d04d50e4899cfd9ba1d405ee1efb9..881808a98bfd688c2efaa8beb5b8f11a2527fee8 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -11,6 +11,16 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "tf_py_test") +py_library( + name = "all_reduce_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":all_reduce", + "//tensorflow/python:util", + ], +) + py_library( name = "all_reduce", srcs = [ diff --git a/tensorflow/contrib/all_reduce/__init__.py b/tensorflow/contrib/all_reduce/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f9824f4cfbf83d9b001a58cafe582226e96c076f --- /dev/null +++ b/tensorflow/contrib/all_reduce/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""All-reduce implementations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.all_reduce.python.all_reduce import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'build_ring_all_reduce', + 'build_recursive_hd_all_reduce', + 'build_shuffle_all_reduce', + 'build_nccl_all_reduce', + 'build_nccl_then_ring', + 'build_nccl_then_recursive_hd', + 'build_nccl_then_shuffle', + 'build_shuffle_then_ring', + 'build_shuffle_then_shuffle' +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 707853b59befc2625145ad96952fbf9f66d62b43..30de7b59af79cb36ee266a15bb6e668c2e3f628a 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/contrib/android/jni/run_stats_jni.h" #include + #include #include "tensorflow/core/protobuf/config.pb.h" @@ -73,7 +74,8 @@ JNIEXPORT jstring RUN_STATS_METHOD(summary)(JNIEnv* env, jclass clazz, StatSummarizer* s = requireHandle(env, handle); if (s == nullptr) return nullptr; std::stringstream ret; - ret << s->GetStatsByMetric("Top 10 CPU", StatSummarizer::BY_TIME, 10) + ret << s->GetStatsByMetric("Top 10 CPU", tensorflow::StatsCalculator::BY_TIME, + 10) << s->GetStatsByNodeType() << s->ShortSummary(); return env->NewStringUTF(ret.str().c_str()); } diff --git a/tensorflow/contrib/autograph/CONTRIBUTING.md b/tensorflow/contrib/autograph/CONTRIBUTING.md index a7a3fe1452d2a3e9c2a37a25ae96f541f8f939e0..a4aec8c74a9ad1418072471a5d3cde8c3b968a38 100644 --- a/tensorflow/contrib/autograph/CONTRIBUTING.md +++ b/tensorflow/contrib/autograph/CONTRIBUTING.md @@ -2,6 +2,9 @@ We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below. +## TensorFlow Code of Conduct +Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md). + ## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License @@ -28,7 +31,7 @@ repository (with credit to the original author) and closes the pull request. ## Style -See the [TensorFlow AutoGraph style guide](STYLE_GUIDE.md). +See the [AutoGraph style guide](STYLE_GUIDE.md). ## Unit tests diff --git a/tensorflow/contrib/autograph/STYLE_GUIDE.md b/tensorflow/contrib/autograph/STYLE_GUIDE.md index 5618ec3e34499ad0f0b2a0d8b0ad04c11ee9bf9c..866e5f583a34570dfddc733f57561ed1d2b7c5bf 100644 --- a/tensorflow/contrib/autograph/STYLE_GUIDE.md +++ b/tensorflow/contrib/autograph/STYLE_GUIDE.md @@ -1,43 +1,26 @@ -# TensorFlow AutoGraph Style Guide +# AutoGraph Style Guide -This page contains style decisions that both developers and users of TensorFlow -AutoGraph should follow to increase the readability of their code, reduce the -number of errors, and promote consistency. We borrow many style principles from the TensorFlow Probability style guide. +This page contains style decisions that developers should follow when +contributing code to AutoGraph. ## TensorFlow Style Follow the [TensorFlow style -guide](https://www.tensorflow.org/community/style_guide) and [documentation -guide](https://www.tensorflow.org/community/documentation). Below are additional -TensorFlow conventions not noted in those guides. In the future, these noted -conventions may be moved upstream. +guide](https://www.tensorflow.org/community/style_guide), the [documentation +guide](https://www.tensorflow.org/community/documentation) and the +[Google Python style guide](https://google.github.io/styleguide/pyguide.html). + +Naming conventions: 1. The name is TensorFlow, not Tensorflow. 2. The name is AutoGraph, not Autograph. -## TensorFlow Code of Conduct -Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md). - -## TensorFlow AutoGraph Style +## AutoGraph Style -Below are TensorFlow AutoGraph-specific conventions. In the event of conflict, +Below are AutoGraph-specific conventions. In the event of conflict, it supercedes all previous conventions. -1. __Importing submodule aliases.__ Use the Pythonic style -`from tensorflow.contrib.autograph.converters import ifexp` and `from tensorflow.contrib import autograph as ag`. - -2. __Examples in Docstrings.__ Write a `#### Examples` subsection below `Args`, - `Returns`, `Raises`, etc. to illustrate examples. If the docstring's last - line is a fence bracket (\`\`\`) closing a code snippet, add an empty line - before closing the docstring with \"\"\". This properly displays the code - snippet. - - Justification: Users regularly need to remind themselves of args and - semantics. But rarely look at examples more than the first time. But since - examples are usually long (which is great!) it means they have to do a lot - of annoying scrolling ...unless Examples follow Args/Returns/Raises. - -3. __Citations in Docstrings.__ Write a `#### References` subsection at the +1. __Citations in Docstrings.__ Write a `#### References` subsection at the bottom of any docstring with citations. Use ICLR’s bibliography style to write references; for example, order entries by the first author's last name. Add a link to the paper if the publication is open source (ideally, @@ -77,21 +60,12 @@ it supercedes all previous conventions. https://arxiv.org/abs/1803.04386 ``` -4. When doing float math over literals eg use `1.` instead of `1` or `1.0`. - - * Using `1.` is another line of defense against an automatic casting - mistake. (Using `1.0` is also such a defense but is not minimal.) - -5. Prefer using named args for functions' 2nd args onward. - - * Definitely use named args for 2nd args onward in docstrings. - -9. Avoid LaTeX in docstrings. +2. Avoid LaTeX in docstrings. * It is not rendered in many (if not most) editors and can be hard to read for both LaTeX experts and non-experts. -10. Write docstring and comment math using ASCII friendly notation; python using +3. Write docstring and comment math using ASCII friendly notation; python using operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`, `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx: x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`. @@ -99,27 +73,3 @@ it supercedes all previous conventions. * The more we stick to python style, the more someone can copy/paste/execute. * Python style is usually easier to read as ASCII. - -11. All public functions require docstrings with: one line description, Args, - Returns, Raises (if raises exceptions). - - * Returns docstrings should be in the same format as Args, eg, of the form - "name: Description." Part of the rationale is that we are suggesting a - reasonable variable name for the returned object(s). - -12. Regard `*args` and/or `**kwargs` as features of last resort. - - * Keyword arguments make the intention of a function call more clear. - * [Possible exceptions for - `kwargs`](https://stackoverflow.com/questions/1415812/why-use-kwargs-in-python-what-are-some-real-world-advantages-over-using-named). - -18. The `__init__.py` file for modules should use TensorFlow's - `remove_undocumented` feature, which seals the module's methods. - -21. Use `"{}".format()` rather than `"" %` for string formatting. - - Justification: [PEP 3101](https://www.python.org/dev/peps/pep-3101/) and - [Python official - tutorials](https://docs.python.org/3.2/tutorial/inputoutput.html#old-string-formatting): - "...this old style of formatting will eventually be removed from the - language, str.format() should generally be used." diff --git a/tensorflow/contrib/autograph/__init__.py b/tensorflow/contrib/autograph/__init__.py index 3386c4eca4b93e850f6fe3c6239d29c61d787ece..dbdbad8f4c91c725294baa36acebbaf5b5e8cf5c 100644 --- a/tensorflow/contrib/autograph/__init__.py +++ b/tensorflow/contrib/autograph/__init__.py @@ -23,18 +23,37 @@ from __future__ import print_function # TODO(mdan): Bring only the relevant symbols to the top level. from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph import operators from tensorflow.contrib.autograph.impl.api import convert from tensorflow.contrib.autograph.impl.api import converted_call from tensorflow.contrib.autograph.impl.api import do_not_convert from tensorflow.contrib.autograph.impl.api import RunMode from tensorflow.contrib.autograph.impl.api import to_code from tensorflow.contrib.autograph.impl.api import to_graph +from tensorflow.contrib.autograph.impl.directives import set_element_type +from tensorflow.contrib.autograph.impl.directives import set_loop_options +from tensorflow.contrib.autograph.impl.special_functions import stack from tensorflow.contrib.autograph.pyct.transformer import AutographParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'utils', 'convert', 'converted_call', 'do_not_convert', 'RunMode', - 'to_code', 'to_graph', 'AutographParseError' + # Main API + 'RunMode', + 'convert', + 'converted_call', + 'do_not_convert', + 'to_code', + 'to_graph', + # Overloaded operators + 'operators', + # Special functions and directives + 'set_element_type', + 'set_loop_options', + 'stack', + # Exceptions + 'AutographParseError', + # Utilities: to be removed + 'utils', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/autograph/converters/BUILD b/tensorflow/contrib/autograph/converters/BUILD index 8f9bffa55e44e4942bb3845945b3d440c7957cc9..284ad84be566199adaaa1ab641d37528ae4dfd2d 100644 --- a/tensorflow/contrib/autograph/converters/BUILD +++ b/tensorflow/contrib/autograph/converters/BUILD @@ -31,6 +31,7 @@ py_library( "name_scopes.py", "side_effect_guards.py", "single_return.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -208,3 +209,14 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py index 35877224b87c1abda1a270be4869e9dcfd0cf97c..775d92c1d9f8bc35d1eda62f3f3ef7ee43414779 100644 --- a/tensorflow/contrib/autograph/converters/break_statements.py +++ b/tensorflow/contrib/autograph/converters/break_statements.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer @@ -34,14 +32,6 @@ CONTROL_VAR_NAME = 'control_var_name' class BreakStatementTransformer(transformer.Base): """Canonicalizes break statements into additional conditionals.""" - def _track_body(self, nodes, break_var): - self.enter_local_scope() - self.set_local(CONTROL_VAR_NAME, break_var) - nodes = self.visit_block(nodes) - break_used = self.get_local(BREAK_USED, False) - self.exit_local_scope() - return nodes, break_used - def visit_Break(self, node): self.set_local(BREAK_USED, True) var_name = self.get_local(CONTROL_VAR_NAME) @@ -54,13 +44,9 @@ class BreakStatementTransformer(transformer.Base): def _guard_if_present(self, block, var_name): """Prevents the block from executing if var_name is set.""" - - # If we don't have statements that immediately depend on the break - # we still need to make sure that the break variable remains - # used, in case the break becomes useful in later stages of transformation. - # Not having this broke the break_in_inner_loop test. if not block: - block = [gast.Pass()] + return block + template = """ if not var_name: block @@ -71,9 +57,17 @@ class BreakStatementTransformer(transformer.Base): block=block) return node + def _track_body(self, nodes, break_var): + self.enter_local_scope() + self.set_local(CONTROL_VAR_NAME, break_var) + nodes = self.visit_block(nodes) + break_used = self.get_local(BREAK_USED, False) + self.exit_local_scope() + return nodes, break_used + def visit_While(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.test = self.visit(node.test) node.body, break_used = self._track_body(node.body, break_var) @@ -81,6 +75,10 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + template = """ var_name = False while test and not var_name: @@ -88,20 +86,18 @@ class BreakStatementTransformer(transformer.Base): else: orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, test=node.test, body=node.body, - orelse=self._guard_if_present(node.orelse, break_var)) + orelse=guarded_orelse) return node def visit_For(self, node): scope = anno.getanno(node, NodeAnno.BODY_SCOPE) - break_var = self.context.namer.new_symbol('break__', scope.referenced) + break_var = self.context.namer.new_symbol('break_', scope.referenced) node.target = self.visit(node.target) node.iter = self.visit(node.iter) @@ -110,19 +106,32 @@ class BreakStatementTransformer(transformer.Base): node.orelse = self.visit_block(node.orelse) if break_used: - node.orelse = self._guard_if_present(node.orelse, break_var) + # Python's else clause only triggers if the loop exited cleanly (e.g. + # break did not trigger). + guarded_orelse = self._guard_if_present(node.orelse, break_var) + extra_test = templates.replace_as_expression( + 'not var_name', var_name=break_var) + + # The extra test is hidden in the AST, which will confuse the static + # analysis. To mitigate that, we insert a no-op statement that ensures + # the control variable is marked as used. + # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name) template = """ var_name = False - for_stmt + for target in iter_: + (var_name,) + body + else: + orelse """ - # Python's else clause only triggers if the loop exited cleanly (e.g. - # break did not trigger). node = templates.replace( template, var_name=break_var, - for_stmt=node) - extra_test = templates.replace_as_expression( - 'not var_name', var_name=break_var) + iter_=node.iter, + target=node.target, + body=node.body, + orelse=guarded_orelse) + anno.setanno(node[1], 'extra_test', extra_test) return node diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py index 317711a866f731de1b497295a2752dee0eb544f5..231e4ee35a72f51845a476d9f605986ac73b4676 100644 --- a/tensorflow/contrib/autograph/converters/builtin_functions.py +++ b/tensorflow/contrib/autograph/converters/builtin_functions.py @@ -31,9 +31,6 @@ class BuiltinFunctionTransformer(transformer.Base): TF equivalent, like `len`. """ - def __init__(self, context): - super(BuiltinFunctionTransformer, self).__init__(context) - def _convert_builtin(self, node): template = """ ag__.utils.dynamic_builtin(func, args) @@ -51,7 +48,7 @@ class BuiltinFunctionTransformer(transformer.Base): # TODO(mdan): This won't work if the function was hidden. # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead. if (isinstance(node.func, gast.Name) and - node.func.id in ('len', 'range', 'xrange')): + node.func.id in ('len', 'range', 'xrange', 'float', 'int')): return self._convert_builtin(node) # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': diff --git a/tensorflow/contrib/autograph/converters/continue_statements.py b/tensorflow/contrib/autograph/converters/continue_statements.py index 4299a8a9d59715d032222c47794bbb4393f34ce6..0417817a77e706fc0ce805f7391bea600f5fbb2d 100644 --- a/tensorflow/contrib/autograph/converters/continue_statements.py +++ b/tensorflow/contrib/autograph/converters/continue_statements.py @@ -24,103 +24,115 @@ from tensorflow.contrib.autograph.pyct import transformer from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno -class ContinueCanonicalizationTransformer(transformer.Base): - """Canonicalizes continue statements into additional conditionals.""" +# Tags for local state. +CONTROL_VAR_NAME = 'control_var_name' +CONTINUE_USED = 'continue_used' +GUARD_CREATED = 'guard_created' +CREATE_GUARD_NEXT = 'create_guard_next' - def __init__(self, context): - super(ContinueCanonicalizationTransformer, self).__init__(context) - # This is a stack structure, to correctly process nested loops. - self.continuation_uses = [] - def _create_continuation_check(self): - template = """ - if not var_name: - pass - """ - cond, = templates.replace(template, var_name=self.continuation_uses[-1][1]) - cond.body = [] - return cond +class ContinueCanonicalizationTransformer(transformer.Base): + """Canonicalizes continue statements into additional conditionals.""" - def _create_continuation_trigger(self): + def visit_Continue(self, node): + self.set_local(CONTINUE_USED, True) template = """ var_name = True """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _create_continuation_init(self): - template = """ - var_name = False - """ - assign, = templates.replace( - template, var_name=self.continuation_uses[-1][1]) - return assign - - def _visit_and_reindent_if_necessary(self, nodes): - reorganized_nodes = [] - current_dest = reorganized_nodes - continue_used_in_block = False - for i, n in enumerate(nodes): - # TODO(mdan): This could be optimized if control structures are simple. - self.continuation_uses[-1][0] = False - n = self.visit(n) - current_dest.append(n) - if self.continuation_uses[-1][0]: - continue_used_in_block = True - if i < len(nodes) - 1: # Last statement in block needs no protection. - cond = self._create_continuation_check() - current_dest.append(cond) - current_dest = cond.body - self.continuation_uses[-1][0] = continue_used_in_block - return reorganized_nodes - - def _process_loop_block(self, block, scope): - cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced) - self.continuation_uses.append([False, cont_var]) - block = self._visit_and_reindent_if_necessary(block) - if self.continuation_uses[-1][0]: - block.insert(0, self._create_continuation_init()) - self.continuation_uses.pop() - return block + return templates.replace( + template, var_name=self.get_local(CONTROL_VAR_NAME)) + + def _postprocess_statement(self, node): + # Example of how the state machine below works: + # + # 1| stmt # State: CONTINUE_USED = False + # | # Action: none + # 2| if cond: + # 3| continue # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = False + # | # Action: set CREATE_GUARD_NEXT = True + # 4| stmt # State: CONTINUE_USED = True, + # | # GUARD_CREATED = False, + # | # CREATE_GUARD_NEXT = True + # | # Action: create `if not continue_used`, + # | # set GUARD_CREATED = True + # 5| stmt # State: CONTINUE_USED = True, GUARD_CREATED = True + # | # Action: none (will be wrapped under previously + # | # created if node) + + if self.get_local(CONTINUE_USED, False): + if self.get_local(GUARD_CREATED, False): + return node, None + + elif not self.get_local(CREATE_GUARD_NEXT, False): + self.set_local(CREATE_GUARD_NEXT, True) + return node, None + + else: + self.set_local(GUARD_CREATED, True) + template = """ + if not var_name: + original_node + """ + cond, = templates.replace( + template, + var_name=self.get_local(CONTROL_VAR_NAME), + original_node=node) + return cond, cond.body + return node, None + + def _visit_loop_body(self, node, nodes): + self.enter_local_scope() + scope = anno.getanno(node, NodeAnno.BODY_SCOPE) + continue_var = self.context.namer.new_symbol('continue_', scope.referenced) + self.set_local(CONTROL_VAR_NAME, continue_var) + + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + + if self.get_local(CONTINUE_USED, False): + template = """ + var_name = False + """ + control_var_init = templates.replace(template, var_name=continue_var) + nodes = control_var_init + nodes + + self.exit_local_scope() + return nodes + + def _visit_non_loop_body(self, nodes): + self.enter_local_scope(inherit=(CONTROL_VAR_NAME,)) + nodes = self.visit_block(nodes, after_visit=self._postprocess_statement) + continue_used = self.get_local(CONTINUE_USED, False) + self.exit_local_scope(keep=(CONTINUE_USED,)) + return nodes, continue_used def visit_While(self, node): - self.generic_visit(node.test) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.test = self.visit(node.test) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_For(self, node): - self.generic_visit(node.target) - self.generic_visit(node.iter) - node.body = self._process_loop_block(node.body, - anno.getanno(node, - NodeAnno.BODY_SCOPE)) - for n in node.orelse: - self.generic_visit(n) + node.target = self.generic_visit(node.target) + node.iter = self.generic_visit(node.iter) + node.body = self._visit_loop_body(node, node.body) + # A continue in the else clause applies to the containing scope. + node.orelse, _ = self._visit_non_loop_body(node.orelse) return node def visit_If(self, node): - if self.continuation_uses: - self.generic_visit(node.test) - node.body = self._visit_and_reindent_if_necessary(node.body) - continue_used_in_body = self.continuation_uses[-1][0] - node.orelse = self._visit_and_reindent_if_necessary(node.orelse) - self.continuation_uses[-1][0] = ( - continue_used_in_body or self.continuation_uses[-1][0]) - else: - node = self.generic_visit(node) + node.test = self.generic_visit(node.test) + node.body, continue_used_body = self._visit_non_loop_body(node.body) + node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse) + self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse) return node - def visit_Continue(self, node): - self.continuation_uses[-1][0] = True - return self._create_continuation_trigger() - - def visit_Break(self, node): - assert False, 'break statement should be desugared at this point' + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body, _ = self._visit_non_loop_body(node.body) + return node def transform(node, namer): diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 1a863590f97add9bfa587d1142a09ae26a9fdb44..9d23d9b5b7e8e8480e04fccc1c8c81799abf382b 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -42,7 +42,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((10, 5, 5), sess.run(result.test_fn(constant_op.constant(5)))) @@ -57,7 +57,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.while_loop) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5)))) @@ -75,7 +75,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual((-1, 0), sess.run(result.test_fn(constant_op.constant(1)))) @@ -92,7 +92,7 @@ class ControlFlowTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {}) node = control_flow.transform(node, self.ctx) - with self.compiled(node, control_flow_ops.cond) as result: + with self.compiled(node) as result: with self.test_session() as sess: self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1)))) diff --git a/tensorflow/contrib/autograph/converters/lists.py b/tensorflow/contrib/autograph/converters/lists.py index b49521b2c328f418828a5e92890aa1b169384b70..c15dfff9e8ebd8b96fd4aff82459a6fd7d0ac8ab 100644 --- a/tensorflow/contrib/autograph/converters/lists.py +++ b/tensorflow/contrib/autograph/converters/lists.py @@ -33,82 +33,193 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer -from tensorflow.python.framework import dtypes +from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno + + +# Tags for local state. +POP_USES = 'pop_uses' class ListTransformer(transformer.Base): """Converts lists and related operations to their TF counterpart.""" - def _empty_list(self, node): - if not anno.hasanno(node, 'element_type'): - raise NotImplementedError( - 'type inference for empty lists is not yet supported; ' - 'use set_element_type(, ) to continue') - dtype = anno.getanno(node, 'element_type') - if not isinstance(dtype, dtypes.DType): - # TODO(mdan): Allow non-TF dtypes? - # That would be consistent with the dynamic dispatch pattern, but - # we must make sure that doesn't become confusing. - raise NotImplementedError('element type "%s" not yet supported' % dtype) - - dtype_name = dtype.name - # TODO(mdan): Does it ever make sense not to use tensor lists? + def visit_List(self, node): + node = self.generic_visit(node) template = """ - tf.TensorArray(tf.dtype_name, size=0, dynamic_size=True) + ag__.new_list(elements) """ - return templates.replace_as_expression(template, dtype_name=dtype_name) + return templates.replace_as_expression(template, elements=node) - def _pre_populated_list(self, node): - raise NotImplementedError('pre-populated lists') + def _replace_append_call(self, node): + assert len(node.args) == 1 + assert isinstance(node.func, gast.Attribute) + template = """ + target = ag__.list_append(target, element) + """ + return templates.replace( + template, + target=node.func.value, + element=node.args[0]) + + def _replace_pop_call(self, node): + # Expressions that use pop() are converted to a statement + expression. + # + # For example: + # + # print(target.pop()) + # + # ... is converted to: + # + # target, target_pop = ag__.list_pop(target) + # print(target_pop) + # + # Here, we just generate the variable name and swap it in, + # and _generate_pop_operation will handle the rest. + # + # Multiple uses of pop() are allowed: + # + # print(tartget.pop(), target.pop()) + # print(tartget.pop().pop()) + # + assert isinstance(node.func, gast.Attribute) + scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) + target_node = node.func.value + + # Attempt to use a related name if can get one. Otherwise use something + # generic. + if anno.hasanno(target_node, anno.Basic.QN): + target_name = anno.getanno(target_node, anno.Basic.QN).ssf() + else: + target_name = 'list' + pop_var_name = self.context.namer.new_symbol(target_name, scope.referenced) + + pop_uses = self.get_local(POP_USES, []) + pop_uses.append((node, pop_var_name)) + self.set_local(POP_USES, pop_uses) + + return templates.replace_as_expression('var_name', var_name=pop_var_name) + + def _replace_stack_call(self, node): + assert len(node.args) == 1 + dtype = anno.getanno( + node.args[0], + 'element_type', + default=templates.replace_as_expression('None')) + template = """ + ag__.list_stack( + target, + opts=ag__.ListStackOpts( + element_dtype=dtype, + original_call=orig_call)) + """ + return templates.replace_as_expression( + template, + dtype=dtype, + target=node.args[0], + orig_call=node.func) - def visit_Expr(self, node): + def visit_Call(self, node): node = self.generic_visit(node) - if isinstance(node.value, gast.Call): - call_node = node.value - - if not anno.hasanno(call_node.func, anno.Basic.QN): - return node - qn = anno.getanno(call_node.func, anno.Basic.QN) - - if qn.qn[-1] == 'append' and (len(call_node.args) == 1): - template = """ - target = ag__.utils.dynamic_list_append(target, element) - """ - node = templates.replace( - template, - target=qn.parent.ast(), - element=call_node.args[0]) + + # TODO(mdan): This is insufficient if target is a function argument. + # In the case of function arguments, we need to add the list to the + # function's return value, because it is being modified. + # TODO(mdan): Checking just the name is brittle, can it be improved? + if isinstance(node.func, gast.Attribute): + func_name = node.func.attr + if func_name == 'append' and (len(node.args) == 1): + node = self._replace_append_call(node) + elif func_name == 'pop' and (len(node.args) <= 1): + node = self._replace_pop_call(node) + elif func_name == 'stack' and (len(node.args) == 1): + node = self._replace_stack_call(node) + return node - def _replace_list_constructors(self, targets, values): - for target in targets: - if (isinstance(target, (gast.Tuple, gast.List)) and - isinstance(values, (gast.Tuple, gast.List))): - n_targets = len(target.elts) - for i in range(n_targets): - target_el, value_el = target.elts[i], values.elts[i] - values.elts[i] = self._replace_list_constructors( - (target_el,), value_el) - return values - if isinstance(values, gast.List): - if values.elts: - return self._pre_populated_list(values) - else: - return self._empty_list(values) - return values - - def visit_Assign(self, node): - node = self.generic_visit(node) + def _generate_pop_operation(self, original_call_node, pop_var_name): + assert isinstance(original_call_node.func, gast.Attribute) + + if original_call_node.args: + pop_element = original_call_node.args[0] + else: + pop_element = parser.parse_expression('None') + # The call will be something like "target.pop()", and the dtype is hooked to + # target, hence the func.value. + dtype = anno.getanno( + original_call_node.func.value, + 'element_type', + default=templates.replace_as_expression('None')) + shape = anno.getanno( + original_call_node.func.value, + 'element_shape', + default=templates.replace_as_expression('None')) + + template = """ + target, pop_var_name = ag__.list_pop( + target, element, + opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape)) + """ + return templates.replace( + template, + target=original_call_node.func.value, + pop_var_name=pop_var_name, + element=pop_element, + dtype=dtype, + shape=shape) + + def _postprocess_statement(self, node): + """Inserts any separate pop() calls that node may use.""" + pop_uses = self.get_local(POP_USES, None) + if pop_uses: + replacements = [] + for original_call_node, pop_var_name in pop_uses: + replacements.extend( + self._generate_pop_operation(original_call_node, pop_var_name)) + replacements.append(node) + node = replacements + self.exit_local_scope() + return node, None + + # TODO(mdan): Should we have a generic visit_block instead? + # Right now it feels that a visit_block would add too much magic that's + # hard to follow. + + def _visit_and_process_block(self, block): + return self.visit_block( + block, + before_visit=self.enter_local_scope, + after_visit=self._postprocess_statement) + + def visit_FunctionDef(self, node): + node.args = self.generic_visit(node.args) + node.decorator_list = self.visit_block(node.decorator_list) + node.body = self._visit_and_process_block(node.body) + return node + + def visit_For(self, node): + node.target = self.visit(node.target) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_While(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node + + def visit_If(self, node): + node.test = self.visit(node.test) + node.body = self._visit_and_process_block(node.body) + node.orelse = self._visit_and_process_block(node.orelse) + return node - # Only convert lists when they are assigned to a variable, e.g.: - # l = [] - # TODO(mdan): A similar pattern exists in type_info.py - # We should add a generic "unpack_assignment" function to the base - # transformer, that has the same effect as applying some logic to the SSA - # form. - node.value = self._replace_list_constructors(node.targets, node.value) + def visit_With(self, node): + node.items = self.visit_block(node.items) + node.body = self._visit_and_process_block(node.body) return node diff --git a/tensorflow/contrib/autograph/converters/lists_test.py b/tensorflow/contrib/autograph/converters/lists_test.py index 74c6dc64f197f75eb3e66c01fb078467e8e8ea89..9f18ab9f44dd8c3f341a02b950f75317c676eff8 100644 --- a/tensorflow/contrib/autograph/converters/lists_test.py +++ b/tensorflow/contrib/autograph/converters/lists_test.py @@ -22,74 +22,126 @@ from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.converters import converter_test_base from tensorflow.contrib.autograph.converters import lists from tensorflow.python.framework import dtypes -from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import list_ops from tensorflow.python.platform import test class ListTest(converter_test_base.TestCase): - def test_empty_annotated_list(self): + def test_empty_list(self): def test_fn(): - l = [] - utils.set_element_type(l, dtypes.int32) - l.append(1) - return l + return [] - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: - # TODO(mdan): Attach these additional modules automatically. - result.utils = utils - result.dtypes = dtypes + with self.compiled(node) as result: + tl = result.test_fn() + # Empty tensor lists cannot be evaluated or stacked. + self.assertTrue(isinstance(tl, ops.Tensor)) + self.assertEqual(tl.dtype, dtypes.variant) + + def test_initialized_list(self): + + def test_fn(): + return [1, 2, 3] + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: with self.test_session() as sess: - self.assertAllEqual([1], sess.run(result.test_fn().stack())) + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) - def test_empty_annotated_lists_unpacked(self): + def test_list_append(self): def test_fn(): - l, m = [], [] - utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m + l = [1] + l.append(2) + l.append(3) + return l - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + node = self.parse_and_analyze(test_fn, {}) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node) as result: + with self.test_session() as sess: + tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2, 3]) + + def test_list_pop(self): + + def test_fn(): + l = [1, 2, 3] + utils.set_element_type(l, dtypes.int32, ()) + s = l.pop() + return s, l + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + ts, tl = result.test_fn() + r = list_ops.tensor_list_stack(tl, dtypes.int32) + self.assertAllEqual(sess.run(r), [1, 2]) + self.assertAllEqual(sess.run(ts), 3) + + def test_double_list_pop(self): - def test_empty_annotated_lists_list_unpacked(self): + def test_fn(l): + s = l.pop().pop() + return s + + node = self.parse_and_analyze(test_fn, {}) + node = lists.transform(node, self.ctx) + + with self.compiled(node) as result: + test_input = [1, 2, [1, 2, 3]] + # TODO(mdan): Pass a list of lists of tensor when we fully support that. + # For now, we just pass a regular Python list of lists just to verify that + # the two pop calls are sequenced properly. + self.assertAllEqual(result.test_fn(test_input), 3) + + def test_list_stack(self): + + tf = None # Will be replaced with a mock. def test_fn(): - [l, m] = [], [] + l = [1, 2, 3] utils.set_element_type(l, dtypes.int32) - utils.set_element_type(m, dtypes.int32) - l.append(1) - m.append(2) - return l, m - - node = self.parse_and_analyze(test_fn, {'dtypes': dtypes, 'utils': utils}) + return tf.stack(l) + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) node = lists.transform(node, self.ctx) - with self.compiled(node, tensor_array_ops.TensorArray, - dtypes.int32) as result: + with self.compiled(node, array_ops.stack, dtypes.int32) as result: result.utils = utils result.dtypes = dtypes with self.test_session() as sess: - res_l, res_m = result.test_fn() - self.assertEqual([1], sess.run(res_l.stack())) - self.assertEqual([2], sess.run(res_m.stack())) + self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/converters/slices.py b/tensorflow/contrib/autograph/converters/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..85aeda9c4164eb70329bd50f789eea5441c8fc87 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices.py @@ -0,0 +1,83 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Converter for slice operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import templates +from tensorflow.contrib.autograph.pyct import transformer + + +class SliceTransformer(transformer.Base): + """Converts slicing operations to their TF counterpart. + + Currently, relying on the default slice operator that Tensor uses is + insufficient, because TensorArray and tensor lists use dedicated index read + and write functions. + """ + + def _process_single_assignment(self, target, value): + if not isinstance(target, gast.Subscript): + return None + + template = """ + target = ag__.set_item(target, key, item) + """ + return templates.replace( + template, target=target.value, key=target.slice, item=value) + + def visit_Assign(self, node): + node = self.generic_visit(node) + # TODO(mdan): Support unpackings and multiple assignments. + if len(node.targets) != 1: + raise NotImplementedError('multiple assignment') + replacement = self._process_single_assignment(node.targets[0], node.value) + if replacement is not None: + return replacement + return node + + def visit_Subscript(self, node): + node = self.generic_visit(node) + if not isinstance(node.slice, gast.Index): + # TODO(mdan): It might make more sense to wave them through. + raise NotImplementedError('non-index slice') + + if not isinstance(node.ctx, gast.Load): + # Index writes are handled at a higher level, one at which the rvalue is + # also available. + return node + + dtype = anno.getanno( + node.value, + 'element_type', + default=templates.replace_as_expression('None')) + + template = """ + ag__.get_item( + target, + key, + opts=ag__.GetItemOpts(element_dtype=dtype)) + """ + return templates.replace_as_expression( + template, target=node.value, key=node.slice, dtype=dtype) + + +def transform(node, context): + return SliceTransformer(context).visit(node) diff --git a/tensorflow/contrib/autograph/converters/slices_test.py b/tensorflow/contrib/autograph/converters/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c2d7e1ea1a6c46fcc3a2c6972a24507646ef858 --- /dev/null +++ b/tensorflow/contrib/autograph/converters/slices_test.py @@ -0,0 +1,59 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import utils +from tensorflow.contrib.autograph.converters import converter_test_base +from tensorflow.contrib.autograph.converters import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SliceTest(converter_test_base.TestCase): + + def test_index_access(self): + + def test_fn(l): + utils.set_element_type(l, dtypes.int32) + return l[1] + + node = self.parse_and_analyze( + test_fn, + { + 'utils': utils, + 'dtypes': dtypes + }, + include_type_analysis=True, + ) + node = slices.transform(node, self.ctx) + + with self.compiled(node, dtypes.int32) as result: + result.utils = utils + result.dtypes = dtypes + with self.test_session() as sess: + tl = list_ops.tensor_list_from_tensor( + [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32)) + y = result.test_fn(tl) + self.assertEqual(2, sess.run(y)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/impl/BUILD b/tensorflow/contrib/autograph/impl/BUILD index 54424e26472b8466b8fe68ea848b5463c10224c9..02f16ae1875d6bd1fb87d19f8bfc5cae900391dd 100644 --- a/tensorflow/contrib/autograph/impl/BUILD +++ b/tensorflow/contrib/autograph/impl/BUILD @@ -20,7 +20,9 @@ py_library( "api.py", "config.py", "conversion.py", + "directives.py", "naming.py", + "special_functions.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -69,3 +71,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "special_functions_test", + srcs = ["special_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":impl", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 55a30dc127957b2a9caa053db843380c94bacfbf..7802bbbe27ec5fed891440af2f589801918b3bdd 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -38,6 +38,7 @@ from tensorflow.contrib.autograph.converters import logical_expressions from tensorflow.contrib.autograph.converters import name_scopes from tensorflow.contrib.autograph.converters import side_effect_guards from tensorflow.contrib.autograph.converters import single_return +from tensorflow.contrib.autograph.converters import slices from tensorflow.contrib.autograph.impl import config from tensorflow.contrib.autograph.impl import naming from tensorflow.contrib.autograph.pyct import ast_util @@ -371,6 +372,8 @@ def node_to_graph(node, ctx, nocompile_decorators): # TODO(mdan): Clean this up. # Some intermediate analyses are not required, and some comments got orphaned. + # TODO(mdan): We may assume all converters require analysis to be re-done. + # Past this point, line numbers are no longer accurate so we ignore the # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? @@ -393,6 +396,8 @@ def node_to_graph(node, ctx, nocompile_decorators): node = _static_analysis_pass(node, ctx) node = lists.transform(node, ctx) + node = _static_analysis_pass(node, ctx) + node = slices.transform(node, ctx) node = builtin_functions.transform(node, ctx) node = _static_analysis_pass(node, ctx) diff --git a/tensorflow/contrib/autograph/impl/directives.py b/tensorflow/contrib/autograph/impl/directives.py new file mode 100644 index 0000000000000000000000000000000000000000..aabe5d99394a0cb921196d1c6a6b2a9496ea7545 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/directives.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Directives are special no-op functions that serve as compilation markers. + +They provide static information like type hints, compilation and TensorFlow +overrides. + +These serve as annotations in the compiled code, allowing the user some control +over the compilation process. They have no functional role at runtime. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +UNSPECIFIED = object() + + +def set_element_type(entity, dtype, shape=UNSPECIFIED): + """Indicates that the entity is expected hold items of specified type/shape. + + The staged TensorFlow ops will reflect and assert this data type. Ignored + otherwise. + + Args: + entity: The entity to annotate. + dtype: TensorFlow dtype value to assert for entity. + shape: Optional shape to assert for entity. + """ + del entity + del dtype + del shape + + +def set_loop_options( + parallel_iterations=UNSPECIFIED, + back_prop=UNSPECIFIED, + swap_memory=UNSPECIFIED, + maximum_iterations=UNSPECIFIED): + """Specifies additional arguments to be passed to the enclosing while_loop. + + The parameters apply to and only to the immediately enclosing loop. It only + has effect if the loop is staged as a TF while_loop; otherwise the parameters + have no effect. + + Args: + parallel_iterations: See tf.while_loop. + back_prop: See tf.while_loop. + swap_memory: See tf.while_loop. + maximum_iterations: See tf.while_loop. + """ + del parallel_iterations + del back_prop + del swap_memory + del maximum_iterations diff --git a/tensorflow/contrib/autograph/impl/special_functions.py b/tensorflow/contrib/autograph/impl/special_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a8177c44c88217560fb7f72c77d3ac1aa0c9ec --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Special functions that only make sense for AutoGraph. + +These functions are meant to ensure feature parity between Python and AutoGraph, +so that the exact same code works in both modes. In general, AutoGraph will +replace these calls. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import data_structures + + +def stack(list_or_tensor, element_dtype=None): + """Stacks the input, if it admits the notion of stacking. No-op otherwise. + + For example, a list of tensors can be stacked into a larger tensor. This + function is similar to tf.stack, but it accepts non-lists and lists of + non-tensors as arguments. In the latter case, the function does nothing. + + Args: + list_or_tensor: Any entity. + element_dtype: Optional dtype for the elements in the list. Required if the + input is stackable, and the list is untyped. + + Returns: + If the input is stackable, a new object representing the stacked inputs. + Otherwise it returns list_or_tensor unchanged. + """ + return data_structures.list_stack( + list_or_tensor, + data_structures.ListStackOpts( + element_dtype=element_dtype, original_call=lambda x: x)) diff --git a/tensorflow/contrib/autograph/impl/special_functions_test.py b/tensorflow/contrib/autograph/impl/special_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b52d2a59b5a3e3c92f11343197379c773ecc828 --- /dev/null +++ b/tensorflow/contrib/autograph/impl/special_functions_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for special_functions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.impl import special_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SpecialFunctionsTest(test.TestCase): + + def test_basic(self): + self.assertEqual(special_functions.stack(1), 1) + self.assertListEqual(special_functions.stack([1, 2, 3]), [1, 2, 3]) + # TODO(mdan): This should probably forward to tf.stack. + self.assertTrue( + isinstance( + special_functions.stack( + [constant_op.constant(1), + constant_op.constant(2)]), list)) + + t = constant_op.constant([1.0, 2.0]) + l = list_ops.tensor_list_from_tensor( + t, element_shape=constant_op.constant([], dtype=dtypes.int32)) + self.assertTrue( + tensor_util.is_tensor( + special_functions.stack(l, element_dtype=dtypes.float32))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 18bfec5d9c69912f90414c51ac63ba540cf4d5fc..0c6ab65505ee03e19588adae73d3134399a34b65 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -22,7 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", - "dispatch_context.py", + "slices.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], @@ -52,3 +52,13 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "slices_test", + srcs = ["slices_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index 38b761d97d54bdaee4da91269964469b482895ae..c900fd6af2ea5dfb419f731ee8d8822d68424b27 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -28,6 +28,10 @@ closures for the body. # - the names used in the Python docs, if the operator is a function (e.g. # list_ and x for append, see # https://docs.python.org/3.7/tutorial/datastructures.html) +# +# All operators may accept a final argument named "opts", of a type that +# subclasses namedtuple and contains any arguments that are only required +# for some specializations of the operator. from __future__ import absolute_import from __future__ import division @@ -35,3 +39,12 @@ from __future__ import print_function from tensorflow.contrib.autograph.operators.control_flow import for_stmt from tensorflow.contrib.autograph.operators.control_flow import while_stmt +from tensorflow.contrib.autograph.operators.data_structures import list_append +from tensorflow.contrib.autograph.operators.data_structures import list_pop +from tensorflow.contrib.autograph.operators.data_structures import list_stack +from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts +from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts +from tensorflow.contrib.autograph.operators.data_structures import new_list +from tensorflow.contrib.autograph.operators.slices import get_item +from tensorflow.contrib.autograph.operators.slices import GetItemOpts +from tensorflow.contrib.autograph.operators.slices import set_item diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py index c862306baa9e8114a71a26323ddcbd35c8592c55..06d8727b0fcc30b532b3f11281cd1a83c51ac8bc 100644 --- a/tensorflow/contrib/autograph/operators/data_structures.py +++ b/tensorflow/contrib/autograph/operators/data_structures.py @@ -18,39 +18,250 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variables + + +# TODO(mdan): Once control flow supports objects, repackage as a class. + + +def new_list(iterable=None): + """The list constructor. + + Args: + iterable: Optional elements to fill the list with. + + Returns: + A list-like object. The exact return value depends on the initial elements. + """ + if iterable: + elements = tuple(iterable) + else: + elements = () + + # TODO(mdan): Extend these criteria. + if any(isinstance(el, variables.Variable) for el in elements): + return _py_list_new(elements) + return _tf_tensor_list_new(elements) -# TODO(mdan): Add support for TensorList once functional. -# TODO(mdan): Add primitives for empty list, list with elements. +def _tf_tensor_list_new(elements): + """Overload of new_list that stages a Tensor list creation.""" + elements = tuple(ops.convert_to_tensor(el) for el in elements) + all_dtypes = set(el.dtype for el in elements) + if len(all_dtypes) == 1: + element_dtype = tuple(all_dtypes)[0] + else: + # Heterogeneous lists are ok. + element_dtype = dtypes.variant + + # TODO(mdan): This may fail for elements of variable shapes. + all_shapes = set(tuple(el.shape.as_list()) for el in elements) + if len(all_shapes) == 1: + element_shape = array_ops.shape(elements[0]) + else: + # Heterogeneous lists are ok. + element_shape = constant_op.constant(-1) # unknown shape, by convention + + l = list_ops.empty_tensor_list( + element_shape=element_shape, element_dtype=element_dtype) + for el in elements: + l = list_ops.tensor_list_push_back(l, el) + return l -def append(target, element): + +def _py_list_new(elements): + """Overload of new_list that creates a Python list.""" + return list(elements) + + +def list_append(list_, x): """The list append function. - Note: it is unspecified where target will be mutated or not. If target is - a TensorFlow entity, it will not be typically mutated. If target is a plain - list, it will be. In general, if the target is mutated then the return value + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value should point to the original entity. Args: - target: An entity that supports append semantics. - element: The element to append. + list_: An entity that supports append semantics. + x: The element to append. Returns: - Same as target, after the append was performed. + Same as list_, after the append was performed. + + Raises: + ValueError: if list_ is not of a known list-like type. """ - if isinstance(target, tensor_array_ops.TensorArray): - return _tf_tensorarray_append(target, element) + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_append(list_, x) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_append(list_, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) else: - return _py_append(target, element) + return _py_list_append(list_, x) + + +def _tf_tensor_list_append(list_, x): + """Overload of list_append that stages a Tensor list write.""" + def empty_list_of_elements_like_x(): + tensor_x = ops.convert_to_tensor(x) + return list_ops.empty_tensor_list( + element_shape=array_ops.shape(tensor_x), + element_dtype=tensor_x.dtype) + + list_ = control_flow_ops.cond( + list_ops.tensor_list_length(list_) > 0, + lambda: list_, + empty_list_of_elements_like_x, + ) + return list_ops.tensor_list_push_back(list_, x) + + +def _tf_tensorarray_append(list_, x): + """Overload of list_append that stages a TensorArray write.""" + return list_.write(list_.size(), x) + + +def _py_list_append(list_, x): + """Overload of list_append that executes a Python list append.""" + # Revert to the original call. + list_.append(x) + return list_ + + +class ListPopOpts( + collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))): + pass + + +def list_pop(list_, i, opts): + """The list pop function. + + Note: it is unspecified where list_ will be mutated or not. If list_ is + a TensorFlow entity, it will not be typically mutated. If list_ is a plain + list, it will be. In general, if the list is mutated then the return value + should point to the original entity. + + Args: + list_: An entity that supports pop semantics. + i: Optional index to pop from. May be None. + opts: A ListPopOpts. + + Returns: + Tuple (x, out_list_): + out_list_: same as list_, after the removal was performed. + x: the removed element value. + + Raises: + ValueError: if list_ is not of a known list-like type or the operation is + not supported for that type. + """ + assert isinstance(opts, ListPopOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + raise ValueError('TensorArray does not support item removal') + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_pop(list_, i, opts) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % list_) + else: + return _py_list_pop(list_, i) + + +def _tf_tensor_list_pop(list_, i, opts): + """Overload of list_pop that stages a Tensor list pop.""" + if i is not None: + raise NotImplementedError('tensor lists only support removing from the end') + + if opts.element_dtype is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'type; use set_element_type to annotate it') + if opts.element_shape is None: + raise ValueError('cannot pop from a list without knowing its element ' + 'shape; use set_element_type to annotate it') + list_out, x = list_ops.tensor_list_pop_back( + list_, element_dtype=opts.element_dtype) + x.set_shape(opts.element_shape) + return list_out, x + + +def _py_list_pop(list_, i): + """Overload of list_pop that executes a Python list append.""" + if i is None: + x = list_.pop() + else: + x = list_.pop(i) + return list_, x + + +# TODO(mdan): Look into reducing duplication between all these containers. +class ListStackOpts( + collections.namedtuple('ListStackOpts', + ('element_dtype', 'original_call'))): + pass + + +def list_stack(list_, opts): + """The list stack function. + + This does not have a direct correspondent in Python. The closest idiom to + this is tf.append or np.stack. It's different from those in the sense that it + accepts a Tensor list, rather than a list of tensors. It can also accept + TensorArray. When the target is anything else, the dispatcher will rely on + ctx.original_call for fallback. + + Args: + list_: An entity that supports append semantics. + opts: A ListStackOpts object. + + Returns: + The output of the stack operation, typically a Tensor. + """ + assert isinstance(opts, ListStackOpts) + + if isinstance(list_, tensor_array_ops.TensorArray): + return _tf_tensorarray_stack(list_) + elif tensor_util.is_tensor(list_): + if list_.dtype == dtypes.variant: + return _tf_tensor_list_stack(list_, opts) + else: + # No-op for primitive Tensor arguments. + return list_ + else: + return _py_list_stack(list_, opts) + + +def _tf_tensorarray_stack(list_): + """Overload of list_stack that stages a TensorArray stack.""" + return list_.stack() -def _tf_tensorarray_append(target, element): - """Overload of append that stages a TensorArray write at the last position.""" - return target.write(target.size(), element) +def _tf_tensor_list_stack(list_, opts): + """Overload of list_stack that stages a Tensor list write.""" + if opts.element_dtype is None: + raise ValueError('cannot stack a list without knowing its element type;' + ' use set_element_type to annotate it') + return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) -def _py_append(target, element): - """Overload of append that executes a Python list append.""" - target.append(element) - return target +def _py_list_stack(list_, opts): + """Overload of list_stack that executes a Python list append.""" + # Revert to the original call. + return opts.original_call(list_) diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py index 577d28c34da39f1216669513c29a00ac07bec126..8bbb52d6c10b241ec754c7dea599fa15a869595f 100644 --- a/tensorflow/contrib/autograph/operators/data_structures_test.py +++ b/tensorflow/contrib/autograph/operators/data_structures_test.py @@ -19,25 +19,98 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.autograph.operators import data_structures +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test -class AppendTest(test.TestCase): +class ListTest(test.TestCase): - def test_tf_tensorarray(self): + def test_new_list_empty(self): + l = data_structures.new_list() + # Can't evaluate an empty list. + # TODO(mdan): sess.run should allow tf.variant maybe? + self.assertTrue(isinstance(l, ops.Tensor)) + + def test_new_list_tensor(self): + l = data_structures.new_list([3, 4, 5]) + t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4, 5]) + + def test_append_tensor_list(self): + l = data_structures.new_list() + x = constant_op.constant([1, 2, 3]) + l = data_structures.list_append(l, x) + + t = list_ops.tensor_list_stack(l, element_dtype=x.dtype) + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [[1, 2, 3]]) + + def test_append_tensorarray(self): l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True) - l1 = data_structures.append(l, 1) - l2 = data_structures.append(l1, 2) + l1 = data_structures.list_append(l, 1) + l2 = data_structures.list_append(l1, 2) with self.test_session() as sess: self.assertAllEqual(sess.run(l1.stack()), [1]) self.assertAllEqual(sess.run(l2.stack()), [1, 2]) - def test_python(self): + def test_append_python(self): l = [] - self.assertAllEqual(data_structures.append(l, 1), [1]) - self.assertAllEqual(data_structures.append(l, 2), [1, 2]) + self.assertAllEqual(data_structures.list_append(l, 1), [1]) + self.assertAllEqual(data_structures.list_append(l, 2), [1, 2]) + + def test_pop_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListPopOpts( + element_dtype=initial_list.dtype, + element_shape=(2,)) + + with self.assertRaises(NotImplementedError): + data_structures.list_pop(l, 0, opts) + + with self.test_session() as sess: + l, x = data_structures.list_pop(l, None, opts) + self.assertAllEqual(sess.run(x), [3, 4]) + + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[1, 2]]) + + def test_pop_python(self): + l = [1, 2, 3] + opts = data_structures.ListPopOpts(element_dtype=None, element_shape=()) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3)) + self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2)) + + def test_stack_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + + opts = data_structures.ListStackOpts( + element_dtype=initial_list.dtype, original_call=None) + + with self.test_session() as sess: + t = data_structures.list_stack(l, opts) + self.assertAllEqual(sess.run(t), sess.run(initial_list)) + + def test_stack_fallback(self): + + def dummy_function(l): + # Lazy person's mock: just transform the argument in a way in which we + # can check that this function was indeed called. + return [x * 2 for x in l] + + opts = data_structures.ListStackOpts( + element_dtype=None, original_call=dummy_function) + + self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4]) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py new file mode 100644 index 0000000000000000000000000000000000000000..04fbeb2f6e39234cad139442704fd7a8d0f56172 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices.py @@ -0,0 +1,133 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Operators specific to slicing operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops +from tensorflow.python.ops import tensor_array_ops + + +# TODO(mdan): Support extended slices. + + +class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))): + pass + + +def get_item(target, i, opts): + """The slice read operator (i.e. __getitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports getitem semantics. + i: Index to read from. + opts: A GetItemOpts object. + + Returns: + The read element. + + Raises: + ValueError: if target is not of a supported type. + """ + assert isinstance(opts, GetItemOpts) + + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_get_item(target, i) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_get_item(target, i, opts) + else: + return _tf_tensor_get_item(target, i) + else: + return _py_get_item(target, i) + + +def _tf_tensorarray_get_item(target, i): + """Overload of get_item that stages a TensorArray read.""" + return target.read(i) + + +def _tf_tensor_list_get_item(target, i, opts): + """Overload of get_item that stages a Tensor list read.""" + if opts.element_dtype is None: + raise ValueError('cannot retrieve from a list without knowing its ' + 'element type; use set_element_type to annotate it') + x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype) + return x + + +def _tf_tensor_get_item(target, i): + """Overload of get_item that stages a Tensor (not Tensor list) read.""" + return target[i] + + +def _py_get_item(target, i): + """Overload of get_item that executes a Python list modification.""" + return target[i] + + +def set_item(target, i, x): + """The slice write operator (i.e. __setitem__). + + Note: it is unspecified whether target will be mutated or not. In general, + if target is mutable (like Python lists), it will be mutated. + + Args: + target: An entity that supports setitem semantics. + i: Index to modify. + x: The new element value. + + Returns: + Same as target, after the update was performed. + + Raises: + ValueError: if target is not of a supported type. + """ + if isinstance(target, tensor_array_ops.TensorArray): + return _tf_tensorarray_set_item(target, i, x) + elif tensor_util.is_tensor(target): + if target.dtype == dtypes.variant: + return _tf_tensor_list_set_item(target, i, x) + else: + raise ValueError( + 'tensor lists are expected to be Tensors with dtype=tf.variant,' + ' instead found %s' % target) + else: + return _py_set_item(target, i, x) + + +def _tf_tensorarray_set_item(target, i, x): + """Overload of set_item that stages a TensorArray write.""" + return target.write(i, x) + + +def _tf_tensor_list_set_item(target, i, x): + """Overload of set_item that stages a Tensor list update.""" + return list_ops.tensor_list_set_item(target, i, x) + + +def _py_set_item(target, i, x): + """Overload of set_item that executes a Python list modification.""" + target[i] = x + return target diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d4aacb9d2015fec56a8df5ad85a20b733765ba26 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/slices_test.py @@ -0,0 +1,51 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for slices module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.operators import slices +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import list_ops +from tensorflow.python.platform import test + + +class SlicesTest(test.TestCase): + + def test_set_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + l = slices.set_item(l, 0, [5, 6]) + + with self.test_session() as sess: + t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype) + self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]]) + + def test_get_item_tensor_list(self): + initial_list = constant_op.constant([[1, 2], [3, 4]]) + elem_shape = constant_op.constant([2]) + l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape) + t = slices.get_item( + l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype)) + + with self.test_session() as sess: + self.assertAllEqual(sess.run(t), [3, 4]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/autograph/pyct/BUILD b/tensorflow/contrib/autograph/pyct/BUILD index 796ab445c74128e1123e24b67c288e0e3c5ca24c..989b821e53a5cefbe39095e669f9a9e0bec65b8a 100644 --- a/tensorflow/contrib/autograph/pyct/BUILD +++ b/tensorflow/contrib/autograph/pyct/BUILD @@ -130,6 +130,7 @@ py_test( name = "transformer_test", srcs = ["transformer_test.py"], srcs_version = "PY2AND3", + tags = ["no_windows"], deps = [ ":pyct", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py index cc4a7edf02ed7556c9a552d8730e4c7875038c83..ae861627fd65cca057e7bf1af41424e605d4b7a1 100644 --- a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -46,8 +46,15 @@ class Basic(NoValue): '`name_map` allows renaming symbols.') -def getanno(node, key, field_name='___pyct_anno'): - return getattr(node, field_name)[key] +FAIL = object() + + +def getanno(node, key, default=FAIL, field_name='___pyct_anno'): + if (default is FAIL or + (hasattr(node, field_name) and (key in getattr(node, field_name)))): + return getattr(node, field_name)[key] + else: + return default def hasanno(node, key, field_name='___pyct_anno'): @@ -73,5 +80,9 @@ def delanno(node, key, field_name='___pyct_anno'): def copyanno(from_node, to_node, key, field_name='___pyct_anno'): - if hasanno(from_node, key, field_name): - setanno(to_node, key, getanno(from_node, key, field_name), field_name) + if hasanno(from_node, key, field_name=field_name): + setanno( + to_node, + key, + getanno(from_node, key, field_name=field_name), + field_name=field_name) diff --git a/tensorflow/contrib/autograph/pyct/anno_test.py b/tensorflow/contrib/autograph/pyct/anno_test.py index 1d4d9d119e0c45c4bf9dd4e5b8156766489a2e4d..f2c0c8cf05ca4b3671eb653ce56f6da61de54aee 100644 --- a/tensorflow/contrib/autograph/pyct/anno_test.py +++ b/tensorflow/contrib/autograph/pyct/anno_test.py @@ -38,12 +38,14 @@ class AnnoTest(test.TestCase): anno.setanno(node, 'foo', 3) self.assertTrue(anno.hasanno(node, 'foo')) - self.assertEqual(3, anno.getanno(node, 'foo')) + self.assertEqual(anno.getanno(node, 'foo'), 3) + self.assertEqual(anno.getanno(node, 'bar', default=7), 7) anno.delanno(node, 'foo') self.assertFalse(anno.hasanno(node, 'foo')) with self.assertRaises(AttributeError): anno.getanno(node, 'foo') + self.assertIsNone(anno.getanno(node, 'foo', default=None)) def test_copyanno(self): node_1 = ast.Name() diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py index 583cf7ecd7bce31c55de58361ab5295abb5d6707..da07013cf4f4309c0e24adda3017575d942861b7 100644 --- a/tensorflow/contrib/autograph/pyct/qual_names.py +++ b/tensorflow/contrib/autograph/pyct/qual_names.py @@ -205,6 +205,7 @@ class QnResolver(gast.NodeTransformer): return node def visit_Subscript(self, node): + # TODO(mdan): This may no longer apply if we overload getitem. node = self.generic_visit(node) s = node.slice if not isinstance(s, gast.Index): @@ -216,7 +217,11 @@ class QnResolver(gast.NodeTransformer): elif isinstance(s.value, gast.Str): subscript = QN(StringLiteral(s.value.s)) else: - subscript = anno.getanno(node.slice.value, anno.Basic.QN) + # The index may be an expression, case in which a name doesn't make sense. + if anno.hasanno(node.slice.value, anno.Basic.QN): + subscript = anno.getanno(node.slice.value, anno.Basic.QN) + else: + return node if anno.hasanno(node.value, anno.Basic.QN): anno.setanno(node, anno.Basic.QN, QN(anno.getanno(node.value, anno.Basic.QN), diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index c00946f9c41bc68d5c638d71f356b484db1286d1..7d1e65c958d7787ef5ed707d4822d14a83092975 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -17,8 +17,8 @@ This analyzer uses known live values to further infer object types. This may include for instance constructed objects and object member functions. -In addition, the analyzer will also process annotations for TF (staged) type -annotations. +In addition, the analyzer also handles user annotations made in the code (for +example, the autograph.set_element_type function). Requires annotations generated by LiveValuesResolver. """ @@ -44,6 +44,7 @@ from __future__ import print_function import gast from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -136,14 +137,14 @@ class TypeInfoResolver(transformer.Base): def _process_function_arg(self, arg_name): str_name = str(arg_name) + type_holder = arg_name.ast() + self.scope.setval(arg_name, type_holder) if len(self.enclosing_entities) == 1 and str_name in self.context.arg_types: # Forge a node to hold the type information, so that method calls on # it can resolve the type. - type_holder = arg_name.ast() type_string, type_obj = self.context.arg_types[str_name] anno.setanno(type_holder, 'type', type_obj) anno.setanno(type_holder, 'type_fqn', tuple(type_string.split('.'))) - self.scope.setval(arg_name, type_holder) def visit_arg(self, node): self._process_function_arg(anno.getanno(node.arg, anno.Basic.QN)) @@ -159,58 +160,47 @@ class TypeInfoResolver(transformer.Base): # a = b # then for future references to `a` we should have definition = `b` definition = self.scope.getval(qn) - if anno.hasanno(definition, 'type'): - anno.setanno(node, 'type', anno.getanno(definition, 'type')) - anno.setanno(node, 'type_fqn', anno.getanno(definition, 'type_fqn')) - if anno.hasanno(definition, 'element_type'): - anno.setanno(node, 'element_type', - anno.getanno(definition, 'element_type')) + anno.copyanno(definition, node, 'type') + anno.copyanno(definition, node, 'type_fqn') + anno.copyanno(definition, node, 'element_type') + anno.copyanno(definition, node, 'element_shape') return node - def _process_variable_assignment(self, source, targets): - # Special case: constructors. - if isinstance(source, gast.Call): - func = source.func + def _process_variable_assignment(self, target, value): + # Constructors + if isinstance(value, gast.Call): + func = value.func if anno.hasanno(func, 'live_val'): func_obj = anno.getanno(func, 'live_val') if tf_inspect.isclass(func_obj): - anno.setanno(source, 'is_constructor', True) - anno.setanno(source, 'type', func_obj) - anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn')) + anno.setanno(value, 'is_constructor', True) + anno.setanno(value, 'type', func_obj) + anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn')) # TODO(mdan): Raise an error if constructor has side effects. # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - # Multiple targets mean multiple assignment. - for target in targets: - # Tuple target means unpacking. - if isinstance(target, (gast.Tuple, gast.List)): - for i, target_item in enumerate(target.elts): - # Two cases here: - # 1. Static unpacking, e.g. a, b = c, d - # 2. Dynamic unpacking, e.g. a, b = c - # The former case is optimized away. - if isinstance(source, (gast.Tuple, gast.List)): - source_item = source.elts[i] - else: - source_item = gast.Subscript(source, gast.Index(i), ctx=None) - self._process_variable_assignment(source_item, (target_item,)) - elif isinstance(target, (gast.Name, gast.Attribute)): - target_symbol = anno.getanno(target, anno.Basic.QN) - self.scope.setval(target_symbol, source) - else: - raise ValueError('assignment target has unknown type: %s' % target) + if isinstance(target, (gast.Name, gast.Attribute)): + target_symbol = anno.getanno(target, anno.Basic.QN) + self.scope.setval(target_symbol, value) + elif isinstance(target, gast.Subscript): + pass + else: + raise ValueError('assignment target has unknown type: %s' % target) def visit_With(self, node): - for wi in node.items: - if wi.optional_vars is not None: - self._process_variable_assignment(wi.context_expr, (wi.optional_vars,)) + for item in node.items: + if item.optional_vars is not None: + self.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self._process_variable_assignment(node.value, node.targets) + self.apply_to_single_assignments( + node.targets, node.value, self._process_variable_assignment) return node def visit_Call(self, node): @@ -220,23 +210,20 @@ class TypeInfoResolver(transformer.Base): if (anno.getanno(node.func, 'live_val') is self.context.type_annotation_func): - if len(node.args) != 2: - raise ValueError('"%s" must have exactly two parameters' + if len(node.args) < 2 or len(node.args) > 3: + raise ValueError('"%s" must have either two or three parameters' % self.context.type_annotation_func) - target_arg, type_arg = node.args + if len(node.args) == 2: + target_arg, type_arg = node.args + shape_arg = parser.parse_expression('None') + else: + target_arg, type_arg, shape_arg = node.args if not anno.hasanno(target_arg, anno.Basic.QN): raise ValueError('the first argument of "%s" must by a symbol' % self.context.type_annotation_func) - if isinstance(type_arg, gast.Str): - element_type = type_arg.s - elif isinstance(type_arg, gast.Num): - element_type = type_arg.n - else: - if not anno.hasanno(type_arg, 'live_val'): - raise ValueError( - 'the second argument of "%s" must be statically resolvable' % - self.context.type_annotation_func) - element_type = anno.getanno(type_arg, 'live_val') + # TODO(mdan): This is vulnerable to symbol renaming. + element_type = type_arg + element_shape = shape_arg target_symbol = anno.getanno(target_arg, anno.Basic.QN) # Find the definition of this symbol and annotate it with the given @@ -244,7 +231,9 @@ class TypeInfoResolver(transformer.Base): # to receive the same type annotation. definition = self.scope.getval(target_symbol) anno.setanno(node, 'element_type', element_type) + anno.setanno(node, 'element_shape', element_shape) anno.setanno(definition, 'element_type', element_type) + anno.setanno(definition, 'element_shape', element_shape) # TODO(mdan): Should we update references between definition and here? return self.generic_visit(node) diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py index 46b7701624a43073fb7cc612d2678ab851513d91..484562f294bb53a63feeca965b8f94c58aa2a685 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info_test.py @@ -187,14 +187,27 @@ class TypeInfoResolverTest(test.TestCase): def test_fn(): f = [] - f = utils.set_element_type(f, Foo) + f = utils.set_element_type(f, Foo, (1, 2, 3)) return f node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) f_def = node.body[0].body[0].value - self.assertEqual(anno.getanno(f_def, 'element_type'), Foo) + self.assertEqual(anno.getanno(f_def, 'element_type').id, 'Foo') f_ref = node.body[0].body[1].value - self.assertEqual(anno.getanno(f_ref, 'element_type'), Foo) + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') + + def test_type_annotation_args(self): + + class Foo(object): + pass + + def test_fn(f): + utils.set_element_type(f, Foo) + return f + + node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'utils': utils}) + f_ref = node.body[0].body[1].value + self.assertEqual(anno.getanno(f_ref, 'element_type').id, 'Foo') def test_nested_unpacking(self): @@ -210,9 +223,9 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar}) a, b, c = node.body[0].body[1].value.elts - self.assertEquals(Foo, anno.getanno(a, 'type')) - self.assertEquals(Bar, anno.getanno(b, 'type')) - self.assertEquals(Foo, anno.getanno(c, 'type')) + self.assertEquals(anno.getanno(a, 'type'), Foo) + self.assertEquals(anno.getanno(b, 'type'), Bar) + self.assertEquals(anno.getanno(c, 'type'), Foo) self.assertFalse(anno.hasanno(a, 'live_val')) self.assertFalse(anno.hasanno(b, 'live_val')) self.assertFalse(anno.hasanno(c, 'live_val')) @@ -229,8 +242,8 @@ class TypeInfoResolverTest(test.TestCase): node = self._parse_and_analyze(test_fn, {'utils': utils}) a, b = node.body[0].body[2].body[2].value.elts - self.assertEquals(1, anno.getanno(a, 'element_type')) - self.assertEquals(2, anno.getanno(b, 'element_type')) + self.assertEquals(anno.getanno(a, 'element_type').n, 1) + self.assertEquals(anno.getanno(b, 'element_type').n, 2) self.assertFalse(anno.hasanno(a, 'type')) self.assertFalse(anno.hasanno(b, 'type')) self.assertFalse(anno.hasanno(a, 'live_val')) diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py index baf7923fff7c786c1abd05e11fa6ffdb8c8f0912..9c479ebc2fa83d27dc363ae306daedb556734a1f 100644 --- a/tensorflow/contrib/autograph/pyct/templates.py +++ b/tensorflow/contrib/autograph/pyct/templates.py @@ -239,8 +239,13 @@ def replace_as_expression(template, **replacements): raise ValueError( 'single expression expected; for more general templates use replace') node = replacement[0] - if not isinstance(node, gast.Expr): - raise ValueError( - 'the template is expected to generate an expression node; instead ' - 'found %s' % node) - return node.value + node = qual_names.resolve(node) + + if isinstance(node, gast.Expr): + return node.value + elif isinstance(node, gast.Name): + return node + + raise ValueError( + 'the template is expected to generate an expression or a name node;' + ' instead found %s' % node) diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py index 4db6cc0adfad90ffc1a6bbcadfc80215688d271e..60bca8b38dcf62b4e997379d075cfc45511a894f 100644 --- a/tensorflow/contrib/autograph/pyct/transformer.py +++ b/tensorflow/contrib/autograph/pyct/transformer.py @@ -70,14 +70,40 @@ class Base(gast.NodeTransformer): return tuple(self._enclosing_entities) @property - def locel_scope_level(self): + def local_scope_level(self): return len(self._local_scope_state) - def enter_local_scope(self): - self._local_scope_state.append({}) + def enter_local_scope(self, inherit=None): + """Marks entry into a new local scope. - def exit_local_scope(self): - return self._local_scope_state.pop() + Args: + inherit: Optional enumerable of variable names to copy from the + parent scope. + """ + scope_entered = {} + if inherit: + this_scope = self._local_scope_state[-1] + for name in inherit: + if name in this_scope: + scope_entered[name] = this_scope[name] + self._local_scope_state.append(scope_entered) + + def exit_local_scope(self, keep=None): + """Marks exit from the current local scope. + + Args: + keep: Optional enumerable of variable names to copy into the + parent scope. + Returns: + A dict containing the scope that has just been exited. + """ + scope_left = self._local_scope_state.pop() + if keep: + this_scope = self._local_scope_state[-1] + for name in keep: + if name in scope_left: + this_scope[name] = scope_left[name] + return scope_left def set_local(self, name, value): self._local_scope_state[-1][name] = value @@ -91,38 +117,163 @@ class Base(gast.NodeTransformer): print(pretty_printer.fmt(node)) return node - def visit_block(self, nodes): - """Helper equivalent to generic_visit, but for node lists.""" + def visit_block(self, nodes, before_visit=None, after_visit=None): + """A more powerful version of generic_visit for statement blocks. + + An example of a block is the body of an if statement. + + This function allows specifying a postprocessing callback (the + after_visit argument) argument which can be used to move nodes to a new + destination. This is done by after_visit by returning a non-null + second return value, e.g. return new_node, new_destination. + + For example, a transformer could perform the following move: + + foo() + bar() + baz() + + foo() + if cond: + bar() + baz() + + The above could be done with a postprocessor of this kind: + + def after_visit(node): + if node_is_function_call(bar): + new_container_node = build_cond() + new_container_node.body.append(node) + return new_container_node, new_container_node.body + else: + # Once we set a new destination, all subsequent items will be + # moved to it, so we don't need to explicitly handle baz. + return node, None + + Args: + nodes: enumerable of AST node objects + before_visit: optional callable that is called before visiting each item + in nodes + after_visit: optional callable that takes in an AST node and + returns a tuple (new_node, new_destination). It is called after + visiting each item in nodes. Is used in the same was as the + visit_* methods: new_node will replace the node; if not None, + new_destination must be a list, and subsequent nodes will be placed + in this list instead of the list returned by visit_block. + Returns: + A list of AST node objects containing the transformed items fron nodes, + except those nodes that have been relocated using after_visit. + """ results = [] + node_destination = results for node in nodes: + if before_visit: + # TODO(mdan): We can modify node here too, if ever needed. + before_visit() + replacement = self.visit(node) + + if after_visit and replacement: + replacement, new_destination = after_visit(replacement) + else: + new_destination = None + if replacement: if isinstance(replacement, (list, tuple)): - results.extend(replacement) + node_destination.extend(replacement) else: - results.append(replacement) + node_destination.append(replacement) + + # Allow the postprocessor to reroute the remaining nodes to a new list. + if new_destination is not None: + node_destination = new_destination return results + # TODO(mdan): Once we have error tracing, we may be able to just go to SSA. + def apply_to_single_assignments(self, targets, values, apply_fn): + """Applies a fuction to each individual assignment. + + This function can process a possibly-unpacked (e.g. a, b = c, d) assignment. + It tries to break down the unpacking if possible. In effect, it has the same + effect as passing the assigned values in SSA form to apply_fn. + + Examples: + + The following will result in apply_fn(a, c), apply_fn(b, d): + + a, b = c, d + + The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]): + + a, b = c + + The following will result in apply_fn(a, (b, c)): + + a = b, c + + It uses the visitor pattern to allow subclasses to process single + assignments individually. + + Args: + targets: list, tuple of or individual AST node. Should be used with the + targets field of an ast.Assign node. + values: an AST node. + apply_fn: a function of a single argument, which will be called with the + respective nodes of each single assignment. The signaure is + apply_fn(target, value), no return value. + """ + if not isinstance(targets, (list, tuple)): + targets = (targets,) + for target in targets: + if isinstance(target, (gast.Tuple, gast.List)): + for i in range(len(target.elts)): + target_el = target.elts[i] + if isinstance(values, (gast.Tuple, gast.List)): + value_el = values.elts[i] + else: + value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store()) + self.apply_to_single_assignments(target_el, value_el, apply_fn) + else: + # TODO(mdan): Look into allowing to rewrite the AST here. + apply_fn(target, values) + def visit(self, node): source_code = self.context.source_code source_file = self.context.source_file did_enter_function = False - local_scope_state_size = len(self._local_scope_state) + local_scope_size_at_entry = len(self._local_scope_state) try: if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)): - self._enclosing_entities.append(node) did_enter_function = True + if did_enter_function: + self._enclosing_entities.append(node) + if source_code and hasattr(node, 'lineno'): self._lineno = node.lineno self._col_offset = node.col_offset - if anno.hasanno(node, anno.Basic.SKIP_PROCESSING): - return node - return super(Base, self).visit(node) - except (ValueError, AttributeError, KeyError, NotImplementedError, - AssertionError) as e: + if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING): + result = super(Base, self).visit(node) + + # On exception, the local scope integrity is not guaranteed. + if did_enter_function: + self._enclosing_entities.pop() + + if local_scope_size_at_entry != len(self._local_scope_state): + raise AssertionError( + 'Inconsistent local scope stack. Before entering node %s, the' + ' stack had length %d, after exit it has length %d. This' + ' indicates enter_local_scope and exit_local_scope are not' + ' well paired.' % ( + node, + local_scope_size_at_entry, + len(self._local_scope_state) + )) + return result + + except (ValueError, AttributeError, KeyError, NotImplementedError) as e: msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % ( e.__class__.__name__, str(e), try_ast_to_source(node), pretty_printer.fmt(node, color=False)) @@ -130,18 +281,11 @@ class Base(gast.NodeTransformer): line = source_code.splitlines()[self._lineno - 1] else: line = '' + # TODO(mdan): Avoid the printing of the original exception. + # In other words, we need to find how to suppress the "During handling + # of the above exception, another exception occurred" message. six.reraise(AutographParseError, AutographParseError( msg, (source_file, self._lineno, self._col_offset + 1, line)), sys.exc_info()[2]) - finally: - if did_enter_function: - self._enclosing_entities.pop() - - if local_scope_state_size != len(self._local_scope_state): - raise AssertionError( - 'Inconsistent local scope stack. Before entering node %s, the' - ' stack had length %d, after exit it has length %d. This' - ' indicates enter_local_scope and exit_local_scope are not' - ' well paired.') diff --git a/tensorflow/contrib/autograph/pyct/transformer_test.py b/tensorflow/contrib/autograph/pyct/transformer_test.py index f96b0dc377521a482d347436caa98633a0a32c8a..f110e79605945e908e8a49112cf758ec29fa1b11 100644 --- a/tensorflow/contrib/autograph/pyct/transformer_test.py +++ b/tensorflow/contrib/autograph/pyct/transformer_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gast + from tensorflow.contrib.autograph.pyct import anno from tensorflow.contrib.autograph.pyct import context from tensorflow.contrib.autograph.pyct import parser @@ -27,7 +29,7 @@ from tensorflow.python.platform import test class TransformerTest(test.TestCase): - def _context_for_nodetesting(self): + def _context_for_testing(self): return context.EntityContext( namer=None, source_code=None, @@ -53,7 +55,7 @@ class TransformerTest(test.TestCase): anno.setanno(node, 'enclosing_entities', self.enclosing_entities) return self.generic_visit(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(): a = 0 @@ -94,7 +96,7 @@ class TransformerTest(test.TestCase): inner_function, lambda_node), anno.getanno(lambda_expr, 'enclosing_entities')) - def test_statement_info_stack(self): + def test_local_scope_info_stack(self): class TestTransformer(transformer.Base): @@ -116,7 +118,7 @@ class TransformerTest(test.TestCase): def visit_For(self, node): return self._annotate_result(node) - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def test_function(a): """Docstring.""" @@ -142,7 +144,7 @@ class TransformerTest(test.TestCase): self.assertFalse(anno.hasanno(while_node, 'string')) self.assertEqual('1', anno.getanno(while_node, 'test')) - def test_statement_info_stack_checks_integrity(self): + def test_local_scope_info_stack_checks_integrity(self): class TestTransformer(transformer.Base): @@ -155,7 +157,7 @@ class TransformerTest(test.TestCase): self.exit_local_scope() return node - tr = TestTransformer(self._context_for_nodetesting()) + tr = TestTransformer(self._context_for_testing()) def no_exit(a): if a > 0: @@ -174,6 +176,38 @@ class TransformerTest(test.TestCase): with self.assertRaises(AssertionError): tr.visit(node) + def test_visit_block_postprocessing(self): + + class TestTransformer(transformer.Base): + + def _process_body_item(self, node): + if isinstance(node, gast.Assign) and (node.value.id == 'y'): + if_node = gast.If(gast.Name('x', gast.Load(), None), [node], []) + return if_node, if_node.body + return node, None + + def visit_FunctionDef(self, node): + node.body = self.visit_block( + node.body, after_visit=self._process_body_item) + return node + + def test_function(x, y): + z = x + z = y + return z + + tr = TestTransformer(self._context_for_testing()) + + node, _ = parser.parse_entity(test_function) + node = tr.visit(node) + node = node.body[0] + + self.assertEqual(len(node.body), 2) + self.assertTrue(isinstance(node.body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1], gast.If)) + self.assertTrue(isinstance(node.body[1].body[0], gast.Assign)) + self.assertTrue(isinstance(node.body[1].body[1], gast.Return)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD index d3a1b9468892531cbc51bc13de66ef595f1a95f8..d82c17bf2afd01aedf4344f983b02c09abcb9bad 100644 --- a/tensorflow/contrib/autograph/utils/BUILD +++ b/tensorflow/contrib/autograph/utils/BUILD @@ -33,6 +33,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/contrib/autograph/pyct", + "//tensorflow/python:dtypes", "//tensorflow/python:list_ops", "//tensorflow/python:script_ops", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index 211e8eaee9082dd3e4f035e4379871cd2e154a39..998087e056c2cd264399982220d6e0528aab9edb 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import py_func from tensorflow.contrib.autograph.utils import type_check +from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops @@ -38,7 +39,13 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_range(*args, **kwargs) if f is range: return dynamic_range(*args, **kwargs) - raise ValueError('%s is not supported' % f) + if f is int: + return dynamic_int(*args, **kwargs) + if f is float: + return dynamic_float(*args, **kwargs) + + raise NotImplementedError( + 'The "%s" builtin is not yet supported.' % f.__name__) def dynamic_len(list_or_tensor): @@ -52,6 +59,20 @@ def dynamic_len(list_or_tensor): return len(list_or_tensor) +def dynamic_int(num_or_tensor, **kwargs): + """Implementation of int() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs) + return int(num_or_tensor) + + +def dynamic_float(num_or_tensor, **kwargs): + """Implementation of float() using dynamic dispatch.""" + if tensor_util.is_tensor(num_or_tensor): + return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs) + return float(num_or_tensor) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index 163e6984079fea5c3b3d9aeda0ec8048d651686f..0c2312178a921037fa419818bf309d671c33914d 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib.autograph.utils import builtins from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.platform import test @@ -77,7 +78,7 @@ class BuiltinsTest(test.TestCase): return x # Functions that just have the names of builtins are rejected. - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): self.assertEqual(builtins.dynamic_builtin(range, 1), 1) if six.PY2: self.assertListEqual( @@ -87,6 +88,20 @@ class BuiltinsTest(test.TestCase): self.assertListEqual( list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2]) + def test_casts(self): + i = constant_op.constant(2, dtype=dtypes.int32) + f = constant_op.constant(1.0, dtype=dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32) + self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32) + self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32) + + self.assertEqual(builtins.dynamic_builtin(int, True), 1) + self.assertEqual(builtins.dynamic_builtin(int, False), 0) + self.assertEqual(builtins.dynamic_builtin(float, True), 1.0) + self.assertEqual(builtins.dynamic_builtin(float, False), 0.0) + def test_dynamic_print_tf(self): try: out_capturer = six.StringIO() diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index d65c990c87cbc316472237d183c03765416501e7..b27a19b16c08cb588b45949105a6399623e766e1 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -49,6 +49,14 @@ cc_library( ], ) +cc_library( + name = "serial_device_batch_scheduler", + hdrs = ["serial_device_batch_scheduler.h"], + deps = [ + "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], @@ -96,6 +104,7 @@ py_test( name = "batch_ops_test", size = "small", srcs = ["python/ops/batch_ops_test.py"], + shard_count = 5, srcs_version = "PY2AND3", tags = [ "manual", diff --git a/tensorflow/contrib/batching/__init__.py b/tensorflow/contrib/batching/__init__.py index 44fa5f42a73bfb1bf008f6f4eafd14913c88dcfa..1e503a097a7b72d9244b0a1cf57747c4b4122c81 100644 --- a/tensorflow/contrib/batching/__init__.py +++ b/tensorflow/contrib/batching/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Ops and modules related to batch. +@@batch_function_v1 @@batch_function """ from __future__ import absolute_import diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index 921d6917a4e478c3e60771fdc3ae99febc33d2e3..012a51f71101471850d312033c41dcbc4805d44c 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import @@ -83,6 +84,74 @@ def batch_function(num_batch_threads, SparseTensor is not supported. The return value of the decorated function must be a Tensor or a list/tuple of Tensors. + Args: + num_batch_threads: Number of scheduling threads for processing batches + of work. Determines the number of batches processed in parallel. + max_batch_size: Batch sizes will never be bigger than this. + batch_timeout_micros: Maximum number of microseconds to wait before + outputting an incomplete batch. + allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, + does nothing. Otherwise, supplies a list of batch sizes, causing the op + to pad batches up to one of those sizes. The entries must increase + monotonically, and the final entry must equal max_batch_size. + grad_timeout_micros: The timeout to use for the gradient. See the + documentation of the unbatch op for more details. Defaults to 60s. + unbatch_timeout_micros: The timeout to use for unbatching. See the + documentation of the unbatch op for more details. Defaults to 60s. + max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. + + Returns: + The decorated function will return the unbatched computation output Tensors. + """ + + def decorator(fn): # pylint: disable=missing-docstring + + def decorated(*args): # pylint: disable=missing-docstring + types = [arg.dtype for arg in args] + + @function.Defun(*types) + def computation(*computation_args): + return fn(*computation_args) + + with ops.name_scope("batch") as name: + for a in args: + if not isinstance(a, ops.Tensor): + raise ValueError("All arguments to functions decorated with " + "`batch_function` are supposed to be Tensors; " + "found %s" % repr(a)) + for inp in computation.captured_inputs: + print("inp: %s" % inp) + for op in inp.consumers(): + print("op: %s" % op) + return gen_batch_ops.batch_function( + num_batch_threads=num_batch_threads, + max_batch_size=max_batch_size, + batch_timeout_micros=batch_timeout_micros, + allowed_batch_sizes=allowed_batch_sizes, + max_enqueued_batches=max_enqueued_batches, + shared_name=name, + f=computation, + in_tensors=list(args), + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + return decorated + + return decorator + + +def batch_function_v1(num_batch_threads, + max_batch_size, + batch_timeout_micros, + allowed_batch_sizes=None, + grad_timeout_micros=60 * 1000 * 1000, + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): + """Batches the computation done by the decorated function. + + This is the older version of batch_function(). Please use the former instead + of this. + Args: num_batch_threads: Number of scheduling threads for processing batches of work. Determines the number of batches processed in parallel. diff --git a/tensorflow/contrib/batching/python/ops/batch_ops_test.py b/tensorflow/contrib/batching/python/ops/batch_ops_test.py index e22f978dde6f1b7febc771d526201579c20292c7..78468145469df216344bc00f116add250dc51dd3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops_test.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops_test.py @@ -23,7 +23,10 @@ import time from tensorflow.contrib.batching.python.ops import batch_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework.errors import InvalidArgumentError from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_batch_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -185,12 +188,62 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBasicUnbatchV1Decorated(self): + """Tests that the batch_function_v1 decorator works.""" + with self.test_session() as sess: + @batch_ops.batch_function_v1(1, 10, 100000) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testBasicUnbatchDecorated(self): """Tests that the batch_function decorator works.""" with self.test_session() as sess: + # TODO(apassos): Removing this line causes test flakiness! Ideally should + # be investigated. + default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable + @batch_ops.batch_function(1, 10, 100000) def computation(in_t): return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchDecoratedWithCapturedInput(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return in_t + captured_inp0 - captured_inp1 + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) result = computation(inp) thread_results = [] @@ -205,6 +258,114 @@ class BatchOpsTest(test.TestCase): self.assertEqual(thread_results[0], [2]) self.assertEqual(main_results[0], [3]) + def testBatchFunctionOp(self): + """Tests that the batch_function op works.""" + with self.test_session() as sess: + + @function.Defun(dtypes.int32) + def computation(in_t): + return in_t + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + result = gen_batch_ops.batch_function( + [inp], + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, + Tout=[dtypes.int32], + f=computation, + captured_tensors=computation.captured_inputs) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithCapturedInput(self): + """Tests that batch_function op works with captured input.""" + with self.test_session() as sess: + captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) + captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32) + def computation(inp): + return inp + captured_inp0 - captured_inp1 + + result = gen_batch_ops.batch_function( + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + allowed_batch_sizes=[3, 10], + batching_queue="", + f=computation, + in_tensors=[inp], + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [1]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [2]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + + def testBatchFunctionOpWithInputError(self): + """Tests that batch_function op works with error in the inputs.""" + with self.test_session() as sess: + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) + + @function.Defun(dtypes.int32, dtypes.int32) + def computation(in0, in1): + return in0 + in1 + + result = gen_batch_ops.batch_function( + [inp], # computation actually expects 2 inputs. + num_batch_threads=1, + max_batch_size=10, + batch_timeout_micros=100000, # 100ms + batching_queue="", + f=computation, + captured_tensors=computation.captured_inputs, + Tout=[o.type for o in computation.definition.signature.output_arg]) + + with self.assertRaisesRegexp(InvalidArgumentError, + ".*2 arguments.*but 1.*"): + sess.run([result], feed_dict={inp: [2]}) + + def testBasicUnbatchDecoratedWithReshape(self): + """Tests that the batch_function decorator works.""" + with self.test_session() as sess: + + @batch_ops.batch_function(1, 10, 100000) + def computation(in_t): + return array_ops.reshape(in_t, [-1]) + 1 + + inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) + result = computation(inp) + thread_results = [] + + def worker(): + thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) + + worker_thread = threading.Thread(target=worker) + worker_thread.start() + main_results = sess.run([result], feed_dict={inp: [[2]]}) + worker_thread.join() + self.assertEqual(thread_results[0], [2]) + self.assertEqual(main_results[0], [3]) + def testUnbatchTimeout(self): """Tests that the unbatch timeout works.""" with self.test_session() as sess: diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.cc b/tensorflow/contrib/batching/serial_device_batch_scheduler.h similarity index 59% rename from tensorflow/compiler/xla/service/versioned_computation_handle.cc rename to tensorflow/contrib/batching/serial_device_batch_scheduler.h index a693c4695f0e776cf297d0ecd28d6de53bd5c0c6..bf6b7083612018eecf0d1784e60cbbf0c5796fef 100644 --- a/tensorflow/compiler/xla/service/versioned_computation_handle.cc +++ b/tensorflow/contrib/batching/serial_device_batch_scheduler.h @@ -13,20 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/versioned_computation_handle.h" +#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h" -namespace xla { - -string VersionedComputationHandle::ToString() const { - return tensorflow::strings::StrCat(handle.handle(), ":v", version); -} - -std::ostream& operator<<(std::ostream& out, - const VersionedComputationHandle& versioned_handle) { - out << versioned_handle.ToString(); - return out; -} - -} // namespace xla +#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py index 5770bcdd706723394bb06196d24aeb32b8b8491a..68fa415eeaf1d1ae7c2ecf1be1c300eddbfa4e69 100644 --- a/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py +++ b/tensorflow/contrib/bayesflow/python/ops/monte_carlo.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monte Carlo integration and helpers. - -See the @{$python/contrib.bayesflow.monte_carlo} guide. -""" +"""Monte Carlo integration and helpers.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py index 758754feac31f1d2cf10e69d7a9a6d288931c900..911d87fa10570382ee5f03edfc1bfd1d116c8360 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -232,7 +232,13 @@ def _dnn_tree_combined_model_fn(features, return update_op if predict_with_tree_only: - tree_train_logits = tree_logits + if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT: + tree_train_logits = tree_logits + else: + tree_train_logits = control_flow_ops.cond( + global_step > dnn_steps_to_train, + lambda: tree_logits, + lambda: dnn_logits) else: tree_train_logits = dnn_logits + tree_logits diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index 89d0d611d2905492cec09e033b8cbc238ec7fac6..9c36c302210185bc390751a0229a61f2f8cd91b8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -41,7 +41,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeClassifier estimator instance. Args: @@ -66,6 +67,16 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. It is a Tensor of rank 2 and its shape is + [batch_size, num_trees]. + For example, + result_iter = classifier.predict(...) + for result_dict in result_iter: + # access leaf index list by result_dict["leaf_index"] + # which contains one leaf index per tree + Raises: ValueError: If learner_config is not valid. """ @@ -74,7 +85,9 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): # supports second order derivative. def loss_fn(labels, logits, weights=None): result = losses.per_example_maxent_loss( - labels=labels, logits=logits, weights=weights, + labels=labels, + logits=logits, + weights=weights, num_classes=n_classes) return math_ops.reduce_mean(result[0]) else: @@ -102,6 +115,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): 'center_bias': center_bias, 'logits_modifier_function': logits_modifier_function, 'use_core_libs': use_core_libs, + 'output_leaf_index': output_leaf_index, }, model_dir=model_dir, config=config, @@ -124,7 +138,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeRegressor estimator instance. Args: @@ -151,6 +166,13 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ head = head_lib.regression_head( label_name=label_name, @@ -173,6 +195,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, @@ -197,7 +220,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): feature_engineering_fn=None, logits_modifier_function=None, center_bias=True, - use_core_libs=False): + use_core_libs=False, + output_leaf_index=False): """Initializes a GradientBoostedDecisionTreeEstimator estimator instance. Args: @@ -220,6 +244,13 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): the bias. use_core_libs: Whether feature columns and loss are from the core (as opposed to contrib) version of tensorflow. + output_leaf_index: whether to output leaf indices along with predictions + during inference. The leaf node indexes are available in predictions + dict by the key 'leaf_index'. For example, + result_dict = classifier.predict(...) + for example_prediction_result in result_dict: + # access leaf index list by example_prediction_result["leaf_index"] + # which contains one leaf index per tree """ super(GradientBoostedDecisionTreeEstimator, self).__init__( model_fn=model.model_builder, @@ -233,6 +264,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator): 'logits_modifier_function': logits_modifier_function, 'center_bias': center_bias, 'use_core_libs': use_core_libs, + 'output_leaf_index': False, }, model_dir=model_dir, config=config, diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 0d58317bd59331cfcde0e12aeb3a3a03fc45d89b..75ef1b050028b6462b255827c06e836e5c481844 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -68,6 +68,28 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_eval_input_fn, steps=1) classifier.export(self._export_dir_base) + def testThatLeafIndexIsInPredictions(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=[contrib_feature_column.real_valued_column("x")], + output_leaf_index=True) + + classifier.fit(input_fn=_train_input_fn, steps=15) + result_iter = classifier.predict(input_fn=_eval_input_fn) + for prediction_dict in result_iter: + self.assertTrue("leaf_index" in prediction_dict) + self.assertTrue("logits" in prediction_dict) + def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 15ab6d814522ab1dee58dcd71246354fc4d8a483..1ee891198939e53fc5913104b2c2e65dc977823f 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -63,6 +63,8 @@ def model_builder(features, labels, mode, params, config): num_trees = params["num_trees"] use_core_libs = params["use_core_libs"] logits_modifier_function = params["logits_modifier_function"] + output_leaf_index = params["output_leaf_index"] + if features is None: raise ValueError("At least one feature must be specified.") @@ -96,7 +98,8 @@ def model_builder(features, labels, mode, params, config): feature_columns=feature_columns, logits_dimension=head.logits_dimension, features=training_features, - use_core_columns=use_core_libs) + use_core_columns=use_core_libs, + output_leaf_index=output_leaf_index) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] @@ -127,6 +130,9 @@ def model_builder(features, labels, mode, params, config): labels=labels, train_op_fn=_train_op_fn, logits=logits) + if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict: + model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[ + gbdt_batch.LEAF_INDEX] if num_trees: if center_bias: num_trees += 1 diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index b3fe38614e05801b223f0c96f7a70ce7e432a70b..9493c1a1394040db3b744f1b382b20bd5bd1988d 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -59,6 +59,7 @@ const char* kApplyDropoutAttributeName = "apply_dropout"; const char* kApplyAveragingAttributeName = "apply_averaging"; const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights"; const char* kPredictionsTensorName = "predictions"; +const char* kLeafIndexTensorName = "leaf_index"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, @@ -170,15 +171,22 @@ class GradientTreesPredictionOp : public OpKernel { core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { tf_shared_lock l(*ensemble_resource->get_mutex()); - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } else { - DoCompute(context, ensemble_resource); + DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/false); } } - private: - void DoCompute(OpKernelContext* context, - DecisionTreeEnsembleResource* ensemble_resource) { + protected: + // return_output_leaf_index is a boolean variable indicating whether to output + // leaf index in prediction. Though this class invokes only with this param + // value as false, the subclass GradientTreesPredictionVerboseOp will invoke + // with the true value. + virtual void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + const bool return_output_leaf_index) { // Read dense float features list; OpInputList dense_float_features_list; OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( @@ -267,6 +275,14 @@ class GradientTreesPredictionOp : public OpKernel { &output_predictions_t)); auto output_predictions = output_predictions_t->matrix(); + // Allocate output leaf index matrix. + Tensor* output_leaf_index_t = nullptr; + if (return_output_leaf_index) { + OP_REQUIRES_OK(context, context->allocate_output( + kLeafIndexTensorName, + {batch_size, ensemble_resource->num_trees()}, + &output_leaf_index_t)); + } // Run predictor. thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; @@ -288,11 +304,13 @@ class GradientTreesPredictionOp : public OpKernel { i, weight * (num_ensembles - i + start_averaging) / num_ensembles); } MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features, - worker_threads, output_predictions); + worker_threads, output_predictions, + output_leaf_index_t); } else { MultipleAdditiveTrees::Predict( ensemble_resource->decision_tree_ensemble(), trees_to_include, - batch_features, worker_threads, output_predictions); + batch_features, worker_threads, output_predictions, + output_leaf_index_t); } // Output dropped trees and original weights. @@ -302,7 +320,6 @@ class GradientTreesPredictionOp : public OpKernel { {2, static_cast(dropped_trees.size())}, &output_dropout_info_t)); auto output_dropout_info = output_dropout_info_t->matrix(); - for (int32 i = 0; i < dropped_trees.size(); ++i) { output_dropout_info(0, i) = dropped_trees[i]; output_dropout_info(1, i) = original_weights[i]; @@ -326,6 +343,27 @@ class GradientTreesPredictionOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("GradientTreesPrediction").Device(DEVICE_CPU), GradientTreesPredictionOp); +// GradientTreesPredictionVerboseOp is derived from GradientTreesPredictionOp +// and have an additional output of tensor of rank 2 containing leaf ids for +// each tree where an instance ended up with. +class GradientTreesPredictionVerboseOp : public GradientTreesPredictionOp { + public: + explicit GradientTreesPredictionVerboseOp(OpKernelConstruction* const context) + : GradientTreesPredictionOp(context) {} + + protected: + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource, + bool return_output_leaf_index) override { + GradientTreesPredictionOp::DoCompute(context, ensemble_resource, + /*return_output_leaf_index=*/true); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("GradientTreesPredictionVerbose").Device(DEVICE_CPU), + GradientTreesPredictionVerboseOp); + class GradientTreesPartitionExamplesOp : public OpKernel { public: explicit GradientTreesPartitionExamplesOp(OpKernelConstruction* const context) diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 04e32267cc4a00b3169c3abbcbf549805a0fb462..401bec84a20a0fefcddbfa1039a117e65f853633 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -43,47 +43,60 @@ namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; } // namespace -class BaseBuildSplitOp : public OpKernel { +class SplitBuilderState { public: - explicit BaseBuildSplitOp(OpKernelConstruction* const context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", - &feature_column_group_id_)); + explicit SplitBuilderState(OpKernelContext* const context) { + const Tensor* l1_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l1_regularization", &l1_regularization_)); + context->input("l1_regularization", &l1_regularization_t)); + const Tensor* l2_regularization_t; OP_REQUIRES_OK(context, - context->GetAttr("l2_regularization", &l2_regularization_)); - OP_REQUIRES_OK(context, context->GetAttr("tree_complexity_regularization", - &tree_complexity_regularization_)); + context->input("l2_regularization", &l2_regularization_t)); + const Tensor* tree_complexity_regularization_t; + OP_REQUIRES_OK(context, context->input("tree_complexity_regularization", + &tree_complexity_regularization_t)); + const Tensor* min_node_weight_t; OP_REQUIRES_OK(context, - context->GetAttr("min_node_weight", &min_node_weight_)); + context->input("min_node_weight", &min_node_weight_t)); - int strategy; - OP_REQUIRES_OK(context, context->GetAttr("multiclass_strategy", &strategy)); + const Tensor* feature_column_group_id_t; + OP_REQUIRES_OK(context, context->input("feature_column_group_id", + &feature_column_group_id_t)); + + const Tensor* multiclass_strategy_t; + OP_REQUIRES_OK( + context, context->input("multiclass_strategy", &multiclass_strategy_t)); + int strategy = multiclass_strategy_t->scalar()(); OP_REQUIRES( context, boosted_trees::learner::LearnerConfig_MultiClassStrategy_IsValid( strategy), errors::InvalidArgument("Wrong multiclass strategy passed.")); - multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - } - NodeStats ComputeNodeStats(const GradientStats& grad_stats) { - return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, - multiclass_strategy_, grad_stats); - } + multiclass_strategy_ = LearnerConfig_MultiClassStrategy(strategy); - void ReadClassId(OpKernelContext* const context, int32* class_id) { const Tensor* class_id_t; OP_REQUIRES_OK(context, context->input("class_id", &class_id_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()), errors::InvalidArgument("class_id must be a scalar.")); - *class_id = class_id_t->scalar()(); + class_id_ = class_id_t->scalar()(); + + l1_regularization_ = l1_regularization_t->scalar()(); + l2_regularization_ = l2_regularization_t->scalar()(); + tree_complexity_regularization_ = + tree_complexity_regularization_t->scalar()(); + min_node_weight_ = min_node_weight_t->scalar()(); + feature_column_group_id_ = feature_column_group_id_t->scalar()(); + } + + NodeStats ComputeNodeStats(const GradientStats& grad_stats) { + return NodeStats(l1_regularization_, l2_regularization_, min_node_weight_, + multiclass_strategy_, grad_stats); } - void FillLeaf(const int class_id, const NodeStats& best_node_stats, + void FillLeaf(const NodeStats& best_node_stats, boosted_trees::trees::Leaf* leaf) const { - if (class_id == -1) { + if (class_id_ == -1) { // This would be the case either for TREE_PER_CLASS with only 2 classes, // or for other multiclass strategies. for (float f : best_node_stats.weight_contribution) { @@ -93,25 +106,31 @@ class BaseBuildSplitOp : public OpKernel { CHECK(best_node_stats.weight_contribution.size() == 1) << "Weight contribution size = " << best_node_stats.weight_contribution.size(); - leaf->mutable_sparse_vector()->add_index(class_id); + leaf->mutable_sparse_vector()->add_index(class_id_); leaf->mutable_sparse_vector()->add_value( best_node_stats.weight_contribution[0]); } } - protected: + int32 feature_column_group_id() { return feature_column_group_id_; } + float tree_complexity_regularization() { + return tree_complexity_regularization_; + } + + private: LearnerConfig_MultiClassStrategy multiclass_strategy_; - int32 feature_column_group_id_; float l1_regularization_; float l2_regularization_; - float min_node_weight_; float tree_complexity_regularization_; + float min_node_weight_; + int32 class_id_; + int32 feature_column_group_id_; }; -class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildDenseInequalitySplitsOp : public OpKernel { public: explicit BuildDenseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) {} + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -139,9 +158,6 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; partition_boundaries.push_back(0); @@ -185,6 +201,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[root_idx]; @@ -196,7 +213,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats(*gradients_t, *hessians_t, bucket_idx); } root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_bucket_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -206,10 +223,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats g(*gradients_t, *hessians_t, bucket_idx); g *= normalizer_ratio; left_gradient_stats += g; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -220,18 +237,18 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* dense_split = split_info.mutable_split_node()->mutable_dense_float_binary_split(); - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); dense_split->set_threshold( bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } @@ -239,13 +256,10 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { REGISTER_KERNEL_BUILDER(Name("BuildDenseInequalitySplits").Device(DEVICE_CPU), BuildDenseInequalitySplitsOp); -class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { +class BuildSparseInequalitySplitsOp : public OpKernel { public: explicit BuildSparseInequalitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -275,8 +289,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // For each partition (tree node), store starting index for each dimension. PartitionAndDimensionBoundaries partition_boundaries; @@ -354,6 +370,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { const auto& dimension_boundaries = @@ -372,7 +389,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { OP_REQUIRES( context, - bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); // Dimension for bias feature is always 0 @@ -388,7 +405,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { GradientStats root_gradient_stats(*gradients_t, *hessians_t, bias_start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); // Iterate through dimensions. for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { @@ -408,7 +425,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { << bucket_ids_and_dimensions(start_index, 1) << " and for " << bucket_ids_and_dimensions(end_index - 1, 0) << " " << bucket_ids_and_dimensions(end_index - 1, 1); - if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id) { // 0-dimension case which has a first bucket for catch all feature. CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) << "Dimension of bias feature should be 0"; @@ -447,10 +464,10 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { present_gradient_stats - left_gradient_stats; { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats left_stats_default_left = state.ComputeNodeStats( + root_gradient_stats - right_gradient_stats); NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); + state.ComputeNodeStats(right_gradient_stats); if (left_stats_default_left.gain + right_stats_default_left.gain > best_gain) { best_gain = @@ -466,9 +483,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // enough missing examples. if (!fixed_default_direction) { NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); + state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = state.ComputeNodeStats( + root_gradient_stats - left_gradient_stats); if (left_stats_default_right.gain + right_stats_default_right.gain > best_gain) { best_gain = left_stats_default_right.gain + @@ -494,7 +511,7 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_sparse_float_binary_split_default_left() ->mutable_split(); } - dense_split->set_feature_column(feature_column_group_id_); + dense_split->set_feature_column(state.feature_column_group_id()); // Set the feature index for the best feature column. const int64 best_dimension_id = bucket_ids_and_dimensions(best_element_idx, 1); @@ -505,11 +522,11 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(bias_start_index); } } @@ -526,19 +543,14 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { // For each partition, store start indices of feature column dimensions. typedef std::vector> PartitionAndDimensionBoundaries; - - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), BuildSparseInequalitySplitsOp); -class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { +class BuildCategoricalEqualitySplitsOp : public OpKernel { public: explicit BuildCategoricalEqualitySplitsOp(OpKernelConstruction* const context) - : BaseBuildSplitOp(context) { - OP_REQUIRES_OK(context, - context->GetAttr("bias_feature_id", &bias_feature_id_)); - } + : OpKernel(context) {} void Compute(OpKernelContext* const context) override { const Tensor* num_minibatches_t; @@ -561,8 +573,10 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* hessians_t; OP_REQUIRES_OK(context, context->input("hessians", &hessians_t)); - int class_id; - ReadClassId(context, &class_id); + const Tensor* bias_feature_id_t; + OP_REQUIRES_OK(context, + context->input("bias_feature_id", &bias_feature_id_t)); + int64 bias_feature_id = bias_feature_id_t->scalar()(); // Find the number of unique partitions before we allocate the output. std::vector partition_boundaries; @@ -605,16 +619,17 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + SplitBuilderState state(context); for (int root_idx = 0; root_idx < num_elements; ++root_idx) { float best_gain = std::numeric_limits::lowest(); int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. - OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; - NodeStats root_stats = ComputeNodeStats(root_gradient_stats); + NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats); int32 best_feature_idx = 0; NodeStats best_right_node_stats(0); NodeStats best_left_node_stats(0); @@ -625,8 +640,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { left_gradient_stats *= normalizer_ratio; GradientStats right_gradient_stats = root_gradient_stats - left_gradient_stats; - NodeStats left_stats = ComputeNodeStats(left_gradient_stats); - NodeStats right_stats = ComputeNodeStats(right_gradient_stats); + NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats); + NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats); if (left_stats.gain + right_stats.gain > best_gain) { best_gain = left_stats.gain + right_stats.gain; best_left_node_stats = left_stats; @@ -637,21 +652,18 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { SplitInfo split_info; auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); - equality_split->set_feature_column(feature_column_group_id_); + equality_split->set_feature_column(state.feature_column_group_id()); equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); - FillLeaf(class_id, best_left_node_stats, left_child); - FillLeaf(class_id, best_right_node_stats, right_child); + state.FillLeaf(best_left_node_stats, left_child); + state.FillLeaf(best_right_node_stats, right_child); split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = - best_gain - root_stats.gain - tree_complexity_regularization_; + best_gain - root_stats.gain - state.tree_complexity_regularization(); output_partition_ids(root_idx) = partition_ids(start_index); } } - - private: - int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index f06b73c00d0bebb2717a79b7894e2addf914daba..409a2d8f46c331c13aec10542c4967d50575e94a 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -64,6 +64,8 @@ from __future__ import print_function import re from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler +from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops +from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops from tensorflow.contrib.boosted_trees.python.ops import quantile_ops from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops @@ -72,9 +74,11 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops + _BIAS_FEATURE_ID = -1 # Pattern to remove all non alpha numeric from a string. _PATTERN = re.compile(r"[\W_]+") @@ -130,11 +134,14 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler): gradient_shape, hessian_shape, name="StatsAccumulator/{}".format(self._name)) - self._quantile_accumulator = quantile_ops.QuantileAccumulator( - init_stamp_token, - epsilon=epsilon, - num_quantiles=num_quantiles, - name="QuantileAccumulator/{}".format(self._name)) + # Allocate both stats accumulator and quantile accumulator on the same + # device so that we can build splits with fewer RPCs. + with ops.colocate_with(self._stats_accumulator.resource()): + self._quantile_accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token, + epsilon=epsilon, + num_quantiles=num_quantiles, + name="QuantileAccumulator/{}".format(self._name)) class DenseSplitHandler(InequalitySplitHandler): @@ -236,45 +243,74 @@ class DenseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - # Get the aggregated gradients and hessians per - # pair. - # In order to distribute the computation on all the PSs we use the PS that - # had the stats accumulator on. - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - - partition_ids, gains, split_infos = ( - split_handler_ops.build_dense_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_dense_split_scalar + else: + handler = make_dense_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a dense feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_dense_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos class SparseSplitHandler(InequalitySplitHandler): @@ -327,9 +363,6 @@ class SparseSplitHandler(InequalitySplitHandler): multiclass_strategy=multiclass_strategy, init_stamp_token=init_stamp_token, name=name) - # Register sparse_make_stats_update function as an Op to the graph. - g = ops.get_default_graph() - sparse_make_stats_update.add_to_graph(g) self._sparse_float_column = sparse_float_column def scheduled_reads(self): @@ -361,8 +394,8 @@ class SparseSplitHandler(InequalitySplitHandler): are_buckets_ready, buckets = scheduled_reads[0] with ops.name_scope(self._name, "SparseSplitHandler"): (quantile_indices, quantile_values, quantile_shapes, quantile_weights, - example_partition_ids, - feature_ids, gradients, hessians) = sparse_make_stats_update( + example_partition_ids, feature_ids, gradients, + hessians) = sparse_make_stats_update( is_active, are_buckets_ready, self._sparse_float_column.indices, self._sparse_float_column.values, self._sparse_float_column.dense_shape, buckets, @@ -379,42 +412,115 @@ class SparseSplitHandler(InequalitySplitHandler): def make_splits(self, stamp_token, next_stamp_token, class_id): """Create the best split using the accumulated stats and flush the state.""" - # Get the bucket boundaries - are_splits_ready, buckets = ( - self._quantile_accumulator.get_buckets(stamp_token)) - - # After we receive the boundaries from previous iteration we can flush - # the quantile accumulator. - with ops.control_dependencies([buckets]): - flush_quantiles = self._quantile_accumulator.flush( - stamp_token=stamp_token, next_stamp_token=next_stamp_token) - - with ops.device(None): - with ops.device(self._stats_accumulator.resource().device): - num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( - self._stats_accumulator.flush(stamp_token, next_stamp_token)) - - # Put quantile and stats accumulator flushing in the dependency path. - are_splits_ready = control_flow_ops.with_dependencies( - [flush_quantiles, partition_ids], are_splits_ready) - partition_ids, gains, split_infos = ( - split_handler_ops.build_sparse_inequality_splits( - num_minibatches=num_minibatches, - bucket_boundaries=buckets, - partition_ids=partition_ids, - bucket_ids=bucket_ids, - gradients=gradients, - hessians=hessians, - class_id=class_id, - feature_column_group_id=self._feature_column_group_id, - l1_regularization=self._l1_regularization, - l2_regularization=self._l2_regularization, - tree_complexity_regularization=self. - _tree_complexity_regularization, - min_node_weight=self._min_node_weight, - bias_feature_id=_BIAS_FEATURE_ID, - multiclass_strategy=self._multiclass_strategy)) - return (are_splits_ready, partition_ids, gains, split_infos) + if (self._gradient_shape == tensor_shape.scalar() and + self._hessian_shape == tensor_shape.scalar()): + handler = make_sparse_split_scalar + else: + handler = make_sparse_split_tensor + + are_splits_ready, partition_ids, gains, split_infos = ( + handler(self._quantile_accumulator.resource(), + self._stats_accumulator.resource(), stamp_token, + next_stamp_token, self._multiclass_strategy, class_id, + self._feature_column_group_id, self._l1_regularization, + self._l2_regularization, self._tree_complexity_regularization, + self._min_node_weight)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle, + stamp_token, next_stamp_token, multiclass_strategy, + class_id, feature_column_id, l1_regularization, + l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional): + """Function that builds splits for a sparse feature column.""" + # Get the bucket boundaries + are_splits_ready, buckets = ( + gen_quantile_ops.quantile_accumulator_get_buckets( + quantile_accumulator_handles=[quantile_accumulator_handle], + stamp_token=stamp_token)) + # quantile_accumulator_get_buckets returns a list of results per handle that + # we pass to it. In this case we're getting results just for one resource. + are_splits_ready = are_splits_ready[0] + buckets = buckets[0] + + # After we receive the boundaries from previous iteration we can flush + # the quantile accumulator. + with ops.control_dependencies([buckets]): + flush_quantiles = gen_quantile_ops.quantile_accumulator_flush( + quantile_accumulator_handle=quantile_accumulator_handle, + stamp_token=stamp_token, + next_stamp_token=next_stamp_token) + + if is_multi_dimentional: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_tensor_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + else: + num_minibatches, partition_ids, bucket_ids, gradients, hessians = ( + gen_stats_accumulator_ops.stats_accumulator_scalar_flush( + stats_accumulator_handle, stamp_token, next_stamp_token)) + + # Put quantile and stats accumulator flushing in the dependency path. + with ops.control_dependencies([flush_quantiles, partition_ids]): + are_splits_ready = array_ops.identity(are_splits_ready) + partition_ids, gains, split_infos = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=num_minibatches, + bucket_boundaries=buckets, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + class_id=class_id, + feature_column_group_id=feature_column_id, + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, + min_node_weight=min_node_weight, + bias_feature_id=_BIAS_FEATURE_ID, + multiclass_strategy=multiclass_strategy)) + return are_splits_ready, partition_ids, gains, split_infos + + +def _specialize_make_split(func, is_multi_dimentional): + """Builds a specialized version of the function.""" + + @function.Defun( + dtypes.resource, + dtypes.resource, + dtypes.int64, + dtypes.int64, + dtypes.int32, + dtypes.int32, + dtypes.int32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + dtypes.float32, + noinline=True) + def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight): + """Function that builds splits for a sparse feature column.""" + return func( + quantile_accumulator_handle, stats_accumulator_handle, stamp_token, + next_stamp_token, multiclass_strategy, class_id, feature_column_id, + l1_regularization, l2_regularization, tree_complexity_regularization, + min_node_weight, is_multi_dimentional) + + return f + +make_dense_split_scalar = _specialize_make_split(_make_dense_split, + is_multi_dimentional=False) +make_dense_split_tensor = _specialize_make_split(_make_dense_split, + is_multi_dimentional=True) + +make_sparse_split_scalar = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=False) +make_sparse_split_tensor = _specialize_make_split(_make_sparse_split, + is_multi_dimentional=True) @function.Defun( @@ -540,8 +646,9 @@ def sparse_make_stats_update( empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( - [], dtype=dtypes.int64, shape=[0, 2]), empty_float, constant_op.constant( - [0, 1], dtype=dtypes.int64), empty_float) + [], dtype=dtypes.int64, shape=[0, 2]), empty_float, + constant_op.constant([0, 1], dtype=dtypes.int64), + empty_float) handler_active = (sparse_column_indices, sparse_column_values, sparse_column_shape, weights) quantile_indices, quantile_values, quantile_shape, quantile_weights = ( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index 54d03018d9e266beabbbabd78ebbb80cfe689c04..2f2c2302113bf59d6a065d5005c934dc76c2148d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.boosted_trees.lib.learner.batch import ordinal_split_handler from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import split_info_pb2 @@ -65,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.scalar() split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -92,7 +94,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -105,7 +109,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -199,10 +203,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2, 2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -227,7 +231,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] + with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -240,7 +246,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -285,10 +291,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): hessian_shape = tensor_shape.TensorShape([2]) split_handler = ordinal_split_handler.DenseSplitHandler( - l1_regularization=0, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0., + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=3, feature_column_group_id=0, @@ -313,7 +319,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -326,7 +333,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -369,9 +376,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, - tree_complexity_regularization=0, - min_node_weight=0, + l2_regularization=1., + tree_complexity_regularization=0., + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -396,7 +403,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -409,7 +417,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -443,9 +451,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, - min_node_weight=0, + min_node_weight=0., epsilon=0.001, num_quantiles=10, feature_column_group_id=0, @@ -470,7 +478,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -483,7 +492,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -576,7 +585,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): split_handler = ordinal_split_handler.DenseSplitHandler( l1_regularization=0.1, - l2_regularization=1, + l2_regularization=1., tree_complexity_regularization=0.5, min_node_weight=1.5, epsilon=0.001, @@ -603,7 +612,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -616,7 +626,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -685,10 +695,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -713,8 +723,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] - + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( 1, @@ -727,7 +737,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -811,10 +821,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -839,7 +849,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -853,7 +864,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -905,10 +916,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -933,7 +944,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -947,7 +959,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -996,10 +1008,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1024,7 +1036,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, False])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1038,7 +1051,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([False, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1065,10 +1078,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1096,7 +1109,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1110,7 +1124,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits @@ -1138,10 +1152,10 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): class_id = -1 split_handler = ordinal_split_handler.SparseSplitHandler( - l1_regularization=0, - l2_regularization=2, - tree_complexity_regularization=0, - min_node_weight=0, + l1_regularization=0.0, + l2_regularization=2.0, + tree_complexity_regularization=0.0, + min_node_weight=0.0, epsilon=0.01, num_quantiles=2, feature_column_group_id=0, @@ -1166,7 +1180,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): example_weights, is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_1]): - are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + are_splits_ready = split_handler.make_splits( + np.int64(0), np.int64(1), class_id)[0] with ops.control_dependencies([are_splits_ready]): update_2 = split_handler.update_stats_sync( @@ -1180,7 +1195,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): is_active=array_ops.constant([True, True])) with ops.control_dependencies([update_2]): are_splits_ready2, partitions, gains, splits = ( - split_handler.make_splits(1, 2, class_id)) + split_handler.make_splits(np.int64(1), np.int64(2), class_id)) are_splits_ready, are_splits_ready2, partitions, gains, splits = ( sess.run([ are_splits_ready, are_splits_ready2, partitions, gains, splits diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc index 43b00d4c6dc2e0066810012292874314215c41be..c9223afeab233497bce9f680bd44bd10ccfc6491 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -26,7 +26,8 @@ void MultipleAdditiveTrees::Predict( const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions) { + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index) { // Zero out predictions as the model is additive. output_predictions.setZero(); @@ -38,8 +39,13 @@ void MultipleAdditiveTrees::Predict( // Lambda for doing a block of work. auto update_predictions = [&config, &features, &trees_to_include, - &output_predictions](int64 start, int64 end) { + &output_predictions, + &output_leaf_index](int64 start, int64 end) { auto examples_iterable = features.examples_iterable(start, end); + Tensor dummy_tensor(DT_INT32, TensorShape({1, 1})); + tensorflow::TTypes::Matrix output_leaf_index_mat = + output_leaf_index != nullptr ? output_leaf_index->matrix() + : dummy_tensor.matrix(); for (const auto& example : examples_iterable) { for (const int32 tree_idx : trees_to_include) { const boosted_trees::trees::DecisionTreeConfig& tree = @@ -47,6 +53,10 @@ void MultipleAdditiveTrees::Predict( const float tree_weight = config.tree_weights(tree_idx); const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + // Checks if output leaf tree index is required. + if (output_leaf_index != nullptr) { + output_leaf_index_mat(example.example_idx, tree_idx) = leaf_idx; + } const auto& leaf_node = tree.nodes(leaf_idx); QCHECK(leaf_node.has_leaf()) << "Invalid leaf node: " << leaf_node.DebugString(); diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index cc3dc226cdbc88fc7010ada1e7f0e6c0a3913c5f..940531c4ba4bcac19fa980deb091e55b48e0693b 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -33,12 +33,17 @@ class MultipleAdditiveTrees { public: // Predict runs tree ensemble on the given batch and updates // output predictions accordingly, for the given list of trees. + // output_leaf_indices is a pointer to a 2 dimensional tensor. If it is not + // nullptr, this method fills output_leaf_indices with a per-tree leaf id + // where each of the instances from 'features' ended up in. Its shape is num + // examples X num of trees. static void Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, tensorflow::thread::ThreadPool* const worker_threads, - tensorflow::TTypes::Matrix output_predictions); + tensorflow::TTypes::Matrix output_predictions, + Tensor* const output_leaf_index); }; } // namespace models diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc index 4ca18bedb1054ef64c6d4b25bbad04842bab1a6a..462a9ac86fe51d07cfb958d9be49bef84811a52e 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -62,7 +62,8 @@ TEST_F(MultipleAdditiveTreesTest, Empty) { tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_EQ(0, output_matrix(0, 0)); EXPECT_EQ(0, output_matrix(1, 0)); } @@ -99,17 +100,38 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + /*output_leaf_index=*/nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } + // Normal case with leaf node. + { + // Initialize output leaf index tensor, since leaf index is positive in this + // case, initialize with the value of -1. Since there are 2 examples and + // there are 2 trees, initialize leaf output index by 2 * 2. + Tensor output_leaf_index_tensor(DT_INT32, TensorShape({2, 2})); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix, + &output_leaf_index_tensor); + EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). + EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 0, 0)); // 1st leaf for the first example + EXPECT_FLOAT_EQ(0, output_leaf_index_tensor.matrix()( + 1, 0)); // 1st leaf for the second example + EXPECT_FLOAT_EQ(2, output_leaf_index_tensor.matrix()( + 0, 1)); // 2nd leaf for the first example + EXPECT_FLOAT_EQ(1, output_leaf_index_tensor.matrix()( + 1, 1)); // 2nd leaf for the second example + } // Weighted case { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); // -0.4 (bias) + 0.9 (leaf 1). @@ -118,21 +140,21 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { // Drop first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). } // Drop second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias). EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). } // Drop all trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); } @@ -172,7 +194,8 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) @@ -184,7 +207,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, - output_matrix); + output_matrix, nullptr); // bias EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); // bias + leaf 2 @@ -197,7 +220,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout first tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) @@ -206,7 +229,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Dropout second tree. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) @@ -215,7 +238,7 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { // Drop both trees. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, - &threads, output_matrix); + &threads, output_matrix, nullptr); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); @@ -258,7 +281,8 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { // Normal case. { MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, - batch_features_, &threads, output_matrix); + batch_features_, &threads, output_matrix, + nullptr); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index 8ad97fedc923ac50bcaad86e0ba2c2e46df6821b..c120dd8a6c156ec9eb7ba0b6c552f5138bd21a16 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -295,7 +295,7 @@ WeightedQuantilesStream::GetQuantileSpecs( if (eps <= std::numeric_limits::epsilon()) { // Exact quantile computation at the expense of RAM. max_level = 1; - block_size = std::max(max_elements, 2LL); + block_size = std::max(max_elements, int64{2}); } else { // The bottom-most level will become full at most // (max_elements / block_size) times, the level above will become full @@ -315,7 +315,7 @@ WeightedQuantilesStream::GetQuantileSpecs( block_size = static_cast(ceil(max_level / eps)) + 1; } } - return std::make_tuple(max_level, std::max(block_size, 2LL)); + return std::make_tuple(max_level, std::max(block_size, int64{2})); } } // namespace quantiles diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index 7576856dc3a6d0b6681ee9745c875cf46d1e2960..a7e7bfc13cadcea4d29d33e0dbd955bdad6ffcb9 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -195,7 +195,7 @@ class WeightedQuantilesSummary { // designed to be cache-friendly. void Compress(int64 size_hint, double min_eps = 0) { // No-op if we're already within the size requirement. - size_hint = std::max(size_hint, 2LL); + size_hint = std::max(size_hint, int64{2}); if (entries_.size() <= size_hint) { return; } @@ -267,7 +267,7 @@ class WeightedQuantilesSummary { if (entries_.empty()) { return output; } - num_quantiles = std::max(num_quantiles, 2LL); + num_quantiles = std::max(num_quantiles, int64{2}); output.reserve(num_quantiles + 1); // Make successive rank queries to get boundaries. diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index d66f645f62aba84261337eb37d6e3204930f8f15..6491d58794332e9417951753532e018aafb652b1 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -40,6 +40,24 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { return Status::OK(); } +static Status ApplyGradientTreesPredictionVerboseShapeFn(InferenceContext* c) { + string learner_config_str; + c->GetAttr("learner_config", &learner_config_str).IgnoreError(); + LearnerConfig learner_config; + ParseProtoUnlimited(&learner_config, learner_config_str); + + bool reduce_dim; + c->GetAttr("reduce_dim", &reduce_dim).IgnoreError(); + // Sets the shape of the output as a matrix. + c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, + reduce_dim ? learner_config.num_classes() - 1 + : learner_config.num_classes())}); + c->set_output(1, {c->UnknownShape()}); + c->set_output(2, {c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)}); + return Status::OK(); +} + REGISTER_OP("GradientTreesPrediction") .Attr("learner_config: string") .Attr("num_dense_float_features: int >= 0") @@ -90,6 +108,58 @@ drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. )doc"); +REGISTER_OP("GradientTreesPredictionVerbose") + .Attr("learner_config: string") + .Attr("num_dense_float_features: int >= 0") + .Attr("num_sparse_float_features: int >= 0") + .Attr("num_sparse_int_features: int >= 0") + .Attr("use_locking: bool = false") + .Attr("apply_dropout: bool") + .Attr("apply_averaging: bool") + .Attr("center_bias: bool") + .Attr("reduce_dim: bool") + .Input("tree_ensemble_handle: resource") + .Input("seed: int64") + .Input("dense_float_features: num_dense_float_features * float") + .Input("sparse_float_feature_indices: num_sparse_float_features * int64") + .Input("sparse_float_feature_values: num_sparse_float_features * float") + .Input("sparse_float_feature_shapes: num_sparse_float_features * int64") + .Input("sparse_int_feature_indices: num_sparse_int_features * int64") + .Input("sparse_int_feature_values: num_sparse_int_features * int64") + .Input("sparse_int_feature_shapes: num_sparse_int_features * int64") + .Output("predictions: float") + .Output("drop_out_tree_indices_weights: float") + .Output("leaf_index: int32") + .SetShapeFn(ApplyGradientTreesPredictionVerboseShapeFn) + .Doc(R"doc( +Runs multiple additive regression forests predictors on input instances +and computes the final prediction for each class, and outputs a matrix of +leaf ids per each tree in an ensemble. + +learner_config: Config for the learner of type LearnerConfig proto. Prediction +ops for now uses only LearningRateDropoutDrivenConfig config from the learner. +num_dense_float_features: Number of dense float features. +num_sparse_float_features: Number of sparse float features. +num_sparse_int_features: Number of sparse int features. +use_locking: Whether to use locking. +seed: random seed to be used for dropout. +reduce_dim: whether to reduce the dimension (legacy impl) or not. +apply_dropout: whether to apply dropout during prediction. +apply_averaging: whether averaging of tree ensembles should take place. If set +to true, will be based on AveragingConfig from learner_config. +tree_ensemble_handle: The handle to the tree ensemble. +dense_float_features: Rank 2 Tensors containing dense float feature values. +sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices. +sparse_float_feature_values: Rank 1 Tensors containing sparse float values. +sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes. +sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices. +sparse_int_feature_values: Rank 1 Tensors containing sparse int values. +sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes. +predictions: Rank 2 Tensor containing predictions per example per class. +drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices +leaf_index: tensor of rank 2 containing leaf ids for each tree where an instance ended up. +)doc"); + REGISTER_OP("GradientTreesPartitionExamples") .Attr("num_dense_float_features: int >= 0") .Attr("num_sparse_float_features: int >= 0") diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 5d0ebbf73ce1272b51a475f67984db3a181b7130..ca5c7f3d8c78a543c63fbfa9f7eb7c3d348f11b8 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -23,12 +23,6 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_OP("BuildDenseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -36,6 +30,12 @@ REGISTER_OP("BuildDenseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -73,6 +73,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -81,13 +92,6 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildSparseInequalitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("bucket_ids: int64") @@ -95,6 +99,13 @@ REGISTER_OP("BuildSparseInequalitySplits") .Input("hessians: float32") .Input("bucket_boundaries: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -133,6 +144,17 @@ bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -141,19 +163,19 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); REGISTER_OP("BuildCategoricalEqualitySplits") - .Attr("feature_column_group_id: int") - .Attr("bias_feature_id: int") - .Attr("l1_regularization: float") - .Attr("l2_regularization: float") - .Attr("tree_complexity_regularization: float") - .Attr("min_node_weight: float") - .Attr("multiclass_strategy: int") .Input("num_minibatches: int64") .Input("partition_ids: int32") .Input("feature_ids: int64") .Input("gradients: float32") .Input("hessians: float32") .Input("class_id: int32") + .Input("feature_column_group_id: int32") + .Input("bias_feature_id: int64") + .Input("l1_regularization: float") + .Input("l2_regularization: float") + .Input("tree_complexity_regularization: float") + .Input("min_node_weight: float") + .Input("multiclass_strategy: int32") .Output("output_partition_ids: int32") .Output("gains: float32") .Output("split_infos: string") @@ -188,6 +210,17 @@ partition_ids: A rank 1 tensor of partition IDs. feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. +class_id: A scalar, the class id for which we're building the splits. +feature_column_group_id: A scalar, the index of the feature we are spiltting on. +l1_regularization: A scalar, which specifies the l1 regularization term. +l2_regularization: A scalar, which specifies the l2 regularization term. +tree_complexity_regularization: A scalar, which specifies the tree complexity + regularization term. +min_node_weight: A scalar, minimum sum of example hessian needed in a child. + If a split results in a leaf node with a smaller value, the split will not + be considered. +multiclass_strategy: A scalar, specifying the multiclass handling strategy. + See LearnerConfig.MultiClassStrategy for valid values. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits for. gains: A rank 1 tensor, for the computed gain for the created splits. @@ -196,4 +229,3 @@ split_infos: A rank 1 tensor of serialized protos which contains the )doc"); } // namespace tensorflow - // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 7a5f329b7ab3216972180ccbb4c85f2537175422..843420968ac6a6716fdf6b4967146e131139f67c 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -20,6 +20,8 @@ from __future__ import print_function import abc import collections +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops @@ -60,6 +62,7 @@ def _move_tensors(tensors, device): """Moves a list of tensors to a device by concatenating/splitting them.""" # Reset the device setting to avoid weird interactions with device merging # logic. + zero = constant_op.constant(0, dtype=dtypes.int32) with ops.device(None): if all(tensor.shape == tensor_shape.scalar() for tensor in tensors): with ops.device(tensors[0].device): @@ -68,12 +71,11 @@ def _move_tensors(tensors, device): return array_ops.unstack(values) else: with ops.device(tensors[0].device): - sizes = array_ops.stack( - [array_ops.shape(tensor)[0] for tensor in tensors]) - values = array_ops.concat(tensors, axis=0) + sizes = array_ops.stack(array_ops.shape_n(tensors))[:, 0] + values = array_ops.concat(tensors, axis=zero) with ops.device(device): sizes = array_ops.unstack(sizes) - return list(array_ops.split(values, sizes, axis=0)) + return list(array_ops.split(values, sizes, axis=zero)) def _scheduled_stamp_resource_op_runner(batch, stamp): diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py index 58f0d36b0f78eeed6abcec1c4fa696f4ccffa615..7f6e55ae5888fc4ef50e34690d61c3ed303e971a 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py @@ -21,4 +21,5 @@ from __future__ import print_function from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction_verbose # pylint: enable=unused-import diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 50cc00afdcc77fedc9bf8c94a9a6fcf2a28ebde9..19b6b3296db394b07f57a25dbde187eb9195af38 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -201,3 +201,6 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): stamp_token=stamp_token, next_stamp_token=next_stamp_token) return result + + def resource(self): + return self._quantile_accumulator_handle diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index e53d86ec612f299c800753d67ceee79acb5db497..47698d45c81478f2b694aaadc603f742c44d5351 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -58,6 +58,7 @@ NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" NUM_USED_HANDLERS = "num_used_handlers" USED_HANDLERS_MASK = "used_handlers_mask" +LEAF_INDEX = "leaf_index" _FEATURE_NAME_TEMPLATE = "%s_%d" @@ -71,18 +72,24 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, - used_handlers): +def _make_predictions_dict(stamp, + logits, + partition_ids, + ensemble_stats, + used_handlers, + leaf_index=None): """Returns predictions for the given logits and n_classes. Args: stamp: The ensemble stamp. - logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. - that contains predictions when no dropout was applied. + logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. that + contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. used_handlers: A TreeEnsembleUsedHandlerOp result tuple of an int and a - boolean mask.. + boolean mask. + leaf_index: A rank 2 `Tensor` with shape [batch_size, number of trees]. that + contains leaf id for each example prediction. Returns: A dict of predictions. @@ -95,6 +102,8 @@ def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats, result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees result[NUM_USED_HANDLERS] = used_handlers.num_used_handlers result[USED_HANDLERS_MASK] = used_handlers.used_handlers_mask + if leaf_index is not None: + result[LEAF_INDEX] = leaf_index return result @@ -180,8 +189,7 @@ def extract_features(features, feature_columns, use_core_columns): elif isinstance(fc, feature_column_lib._EmbeddingColumn): # pylint: enable=protected-access transformed_features[fc.name] = fc_core.input_layer( - features, [fc], - weight_collections=[scope]) + features, [fc], weight_collections=[scope]) else: result = feature_column_ops.transform_features(features, [fc]) if len(result) > 1: @@ -269,7 +277,8 @@ class GradientBoostedDecisionTreeModel(object): features, logits_dimension, feature_columns=None, - use_core_columns=False): + use_core_columns=False, + output_leaf_index=False): """Construct a new GradientBoostedDecisionTreeModel function. Args: @@ -277,13 +286,15 @@ class GradientBoostedDecisionTreeModel(object): num_ps_replicas: Number of parameter server replicas, can be 0. ensemble_handle: A handle to the ensemble variable. center_bias: Whether to center the bias before growing trees. - examples_per_layer: Number of examples to accumulate before growing - a tree layer. It can also be a function that computes the number of - examples based on the depth of the layer that's being built. + examples_per_layer: Number of examples to accumulate before growing a tree + layer. It can also be a function that computes the number of examples + based on the depth of the layer that's being built. learner_config: A learner config. features: `dict` of `Tensor` objects. logits_dimension: An int, the dimension of logits. feature_columns: A list of feature columns. + output_leaf_index: A boolean variable indicating whether to output leaf + index into predictions dictionary. Raises: ValueError: if inputs are not valid. @@ -334,10 +345,12 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + initial_value=array_ops.zeros([], dtypes.int64), + trainable=False, name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") @@ -354,9 +367,11 @@ class GradientBoostedDecisionTreeModel(object): self._sparse_int_indices = sparse_int_indices self._sparse_int_values = sparse_int_values self._sparse_int_shapes = sparse_int_shapes - self._reduce_dim = (self._learner_config.multi_class_strategy == - learner_pb2.LearnerConfig.TREE_PER_CLASS and - learner_config.num_classes == 2) + self._reduce_dim = ( + self._learner_config.multi_class_strategy == + learner_pb2.LearnerConfig.TREE_PER_CLASS and + learner_config.num_classes == 2) + self._output_leaf_index = output_leaf_index def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): """Runs prediction and returns a dictionary of the prediction results. @@ -374,8 +389,8 @@ class GradientBoostedDecisionTreeModel(object): ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, ensemble_stamp) num_handlers = ( - len(self._dense_floats) + len(self._sparse_float_shapes) + - len(self._sparse_int_shapes)) + len(self._dense_floats) + len(self._sparse_float_shapes) + len( + self._sparse_int_shapes)) # Used during feature selection. used_handlers = model_ops.tree_ensemble_used_handlers( ensemble_handle, ensemble_stamp, num_all_handlers=num_handlers) @@ -386,22 +401,44 @@ class GradientBoostedDecisionTreeModel(object): # Make sure ensemble stats run. This will check that the ensemble has # the right stamp. with ops.control_dependencies(ensemble_stats): - predictions, _ = prediction_ops.gradient_trees_prediction( - ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=mode != learn.ModeKeys.TRAIN, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim) + leaf_index = None + # Only used in infer (predict), not used in train and eval. + if self._output_leaf_index and mode == learn.ModeKeys.INFER: + predictions, _, leaf_index = ( + prediction_ops).gradient_trees_prediction_verbose( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) + else: + leaf_index = None + predictions, _ = prediction_ops.gradient_trees_prediction( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) partition_ids = prediction_ops.gradient_trees_partition_examples( ensemble_handle, self._dense_floats, @@ -414,7 +451,7 @@ class GradientBoostedDecisionTreeModel(object): use_locking=True) return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, - ensemble_stats, used_handlers) + ensemble_stats, used_handlers, leaf_index) def predict(self, mode): """Returns predictions given the features and mode. @@ -432,8 +469,9 @@ class GradientBoostedDecisionTreeModel(object): # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device # as the model. If not, we create a copy of the model on the worker. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) if not input_deps: raise ValueError("No input tensors for prediction.") @@ -457,8 +495,8 @@ class GradientBoostedDecisionTreeModel(object): # Determine whether the local ensemble is stale and update it if needed. def _refresh_local_ensemble_fn(): - # Serialize the model from parameter server after reading all inputs. - with ops.control_dependencies(input_deps): + # Serialize the model from parameter server after reading the inputs. + with ops.control_dependencies([input_deps[0]]): (ensemble_stamp, serialized_model) = ( model_ops.tree_ensemble_serialize(self._ensemble_handle)) @@ -500,8 +538,9 @@ class GradientBoostedDecisionTreeModel(object): ValueError: if inputs are not valid. """ # Get the worker device from input dependencies. - input_deps = (self._dense_floats + self._sparse_float_indices + - self._sparse_int_indices) + input_deps = ( + self._dense_floats + self._sparse_float_indices + + self._sparse_int_indices) worker_device = input_deps[0].device # Get tensors relevant for training and form the loss. @@ -517,7 +556,7 @@ class GradientBoostedDecisionTreeModel(object): aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - class_id = -1 + class_id = constant_op.constant(-1, dtype=dtypes.int32) # Handle different multiclass strategies. if strategy == learner_pb2.LearnerConfig.TREE_PER_CLASS: # We build one vs rest trees. @@ -571,31 +610,39 @@ class GradientBoostedDecisionTreeModel(object): # Get the weights for each example for quantiles calculation, weights = self._get_weights(hessian_shape, squeezed_hessians) - regularization_config = self._learner_config.regularization - min_node_weight = self._learner_config.constraints.min_node_weight # Create all handlers ensuring resources are evenly allocated across PS. fc_name_idx = 0 handlers = [] init_stamp_token = constant_op.constant(0, dtype=dtypes.int64) + l1_regularization = constant_op.constant( + self._learner_config.regularization.l1, dtypes.float32) + l2_regularization = constant_op.constant( + self._learner_config.regularization.l2, dtypes.float32) + tree_complexity_regularization = constant_op.constant( + self._learner_config.regularization.tree_complexity, dtypes.float32) + min_node_weight = constant_op.constant( + self._learner_config.constraints.min_node_weight, dtypes.float32) + epsilon = 0.01 + num_quantiles = 100 + strategy_tensor = constant_op.constant(strategy) with ops.device(self._get_replica_device_setter(worker_device)): # Create handlers for dense float columns for dense_float_column_idx in range(len(self._dense_floats)): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.DenseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=dense_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, dense_float_column=self._dense_floats[dense_float_column_idx], name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -604,14 +651,13 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( ordinal_split_handler.SparseSplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_float_column_idx, - epsilon=0.01, - num_quantiles=100, + epsilon=epsilon, + num_quantiles=num_quantiles, sparse_float_column=sparse_tensor.SparseTensor( self._sparse_float_indices[sparse_float_column_idx], self._sparse_float_values[sparse_float_column_idx], @@ -619,7 +665,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -628,10 +674,9 @@ class GradientBoostedDecisionTreeModel(object): fc_name = self._fc_names[fc_name_idx] handlers.append( categorical_split_handler.EqualitySplitHandler( - l1_regularization=regularization_config.l1, - l2_regularization=regularization_config.l2, - tree_complexity_regularization=( - regularization_config.tree_complexity), + l1_regularization=l1_regularization, + l2_regularization=l2_regularization, + tree_complexity_regularization=tree_complexity_regularization, min_node_weight=min_node_weight, feature_column_group_id=sparse_int_column_idx, sparse_int_column=sparse_tensor.SparseTensor( @@ -641,7 +686,7 @@ class GradientBoostedDecisionTreeModel(object): name=fc_name, gradient_shape=gradient_shape, hessian_shape=hessian_shape, - multiclass_strategy=strategy, + multiclass_strategy=strategy_tensor, init_stamp_token=init_stamp_token)) fc_name_idx += 1 @@ -694,11 +739,11 @@ class GradientBoostedDecisionTreeModel(object): name="continue_centering", trainable=False) stats_update_ops.append( - control_flow_ops.cond(continue_centering, - self._make_update_bias_stats_fn( - ensemble_stamp, predictions, gradients, - bias_stats_accumulator), - control_flow_ops.no_op)) + control_flow_ops.cond( + continue_centering, + self._make_update_bias_stats_fn(ensemble_stamp, predictions, + gradients, bias_stats_accumulator), + control_flow_ops.no_op)) # Update handler stats. handler_reads = collections.OrderedDict() @@ -720,8 +765,8 @@ class GradientBoostedDecisionTreeModel(object): shape=[len(handlers)], seed=[seed + 1, 1]) active_handlers = array_ops.stack( [active_handlers_current_layer, active_handlers_next_layer], axis=1) - active_handlers = (active_handlers < - self._learner_config.feature_fraction_per_level) + active_handlers = ( + active_handlers < self._learner_config.feature_fraction_per_level) elif subsampling_type == "feature_fraction_per_tree": seed = predictions_dict[NUM_TREES_ATTEMPTED] active_handlers_current_layer = stateless.stateless_random_uniform( @@ -729,9 +774,12 @@ class GradientBoostedDecisionTreeModel(object): active_handlers_current_layer = ( active_handlers_current_layer < self._learner_config.feature_fraction_per_tree) - active_handlers = array_ops.stack([ - active_handlers_current_layer, - array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1) + active_handlers = array_ops.stack( + [ + active_handlers_current_layer, + array_ops.ones([len(handlers)], dtype=dtypes.bool) + ], + axis=1) else: active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool) @@ -760,6 +808,7 @@ class GradientBoostedDecisionTreeModel(object): empty_hessians = constant_op.constant( [], dtype=dtypes.float32, shape=empty_hess_shape) + active_handlers = array_ops.unstack(active_handlers, axis=0) for handler_idx in range(len(handlers)): handler = handlers[handler_idx] is_active = active_handlers[handler_idx] @@ -901,7 +950,6 @@ class GradientBoostedDecisionTreeModel(object): "DecisionTreeEnsembleResourceHandleOp", "StatsAccumulatorScalarResourceHandleOp", "StatsAccumulatorTensorResourceHandleOp", - "QuantileStreamResourceHandleOp", ] ps_strategy = _OpRoundRobinStrategy(ps_ops, ps_tasks) return device_setter.replica_device_setter( @@ -971,7 +1019,7 @@ class GradientBoostedDecisionTreeModel(object): # This is a workaround for the slowness of graph building in tf.cond. # See (b/36554864). split_sizes = array_ops.reshape( - array_ops.shape_n(partition_ids_list), [-1]) + array_ops.shape_n(partition_ids_list), [len(partition_ids_list)]) partition_ids = array_ops.concat(partition_ids_list, axis=0) gains = array_ops.concat(gains_list, axis=0) split_infos = array_ops.concat(split_info_list, axis=0) @@ -1036,8 +1084,11 @@ class GradientBoostedDecisionTreeModel(object): # Update ensemble. update_ops = [are_all_splits_ready] - update_model = control_flow_ops.cond(continue_centering, _center_bias_fn, - _grow_ensemble_fn) + if self._center_bias: + update_model = control_flow_ops.cond(continue_centering, + _center_bias_fn, _grow_ensemble_fn) + else: + update_model = _grow_ensemble_fn() update_ops.append(update_model) # Update ensemble stats. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index f9c22283b7f5136777bfa60a12c94974adfbd245..e3d4397fadcbaf148f7f6cfaca13e850639786cf 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -19,19 +19,15 @@ from __future__ import division from __future__ import print_function from google.protobuf import text_format - from tensorflow.contrib import layers from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch from tensorflow.contrib.boosted_trees.python.utils import losses - -from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.learn.python.learn.estimators import model_fn - - +from tensorflow.python.feature_column import feature_column_lib as core_feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util @@ -97,8 +93,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_int"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros([2], dtypes.int64), - array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.int64), array_ops.zeros([2], + dtypes.int64)) (fc_names, dense_floats, sparse_float_indices, sparse_float_values, sparse_float_shapes, sparse_int_indices, sparse_int_values, sparse_int_shapes) = ( @@ -139,8 +135,8 @@ class GbdtTest(test_util.TensorFlowTestCase): array_ops.zeros([2], dtypes.int64)) features["sparse_categorical"] = sparse_tensor.SparseTensor( array_ops.zeros([2, 2], dtypes.int64), - array_ops.zeros( - [2], dtypes.string), array_ops.zeros([2], dtypes.int64)) + array_ops.zeros([2], dtypes.string), array_ops.zeros([2], + dtypes.int64)) feature_columns = set() feature_columns.add(layers.real_valued_column("dense_float")) feature_columns.add( @@ -235,7 +231,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -316,6 +313,113 @@ class GbdtTest(test_util.TensorFlowTestCase): }""" self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefSparseAndDense(self): + """Tests the train function with sparse and dense features.""" + with self.test_session() as sess: + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, tree_ensemble_config="", name="tree_ensemble") + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.ones([4, 1], dtypes.float32) + features["sparse_float"] = sparse_tensor.SparseTensor( + array_ops.zeros([2, 2], dtypes.int64), + array_ops.zeros([2], dtypes.float32), + array_ops.constant([4, 1], dtypes.int64)) + + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=True, + num_ps_replicas=0, + center_bias=False, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features) + + predictions = array_ops.constant( + [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) + partition_ids = array_ops.zeros([4], dtypes.int32) + ensemble_stamp = variables.Variable( + initial_value=0, + name="ensemble_stamp", + trainable=False, + dtype=dtypes.int64) + + predictions_dict = { + "predictions": predictions, + "predictions_no_dropout": predictions, + "partition_ids": partition_ids, + "ensemble_stamp": ensemble_stamp, + "num_trees": 12, + } + + labels = array_ops.ones([4, 1], dtypes.float32) + weights = array_ops.ones([4, 1], dtypes.float32) + # Create train op. + train_op = gbdt_model.train( + loss=math_ops.reduce_mean( + _squared_loss(labels, weights, predictions)), + predictions_dict=predictions_dict, + labels=labels) + variables.global_variables_initializer().run() + resources.initialize_resources(resources.shared_resources()).run() + + # On first run, expect no splits to be chosen because the quantile + # buckets will not be ready. + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 0) + self.assertEquals(len(output.tree_weights), 0) + self.assertEquals(stamp_token.eval(), 1) + + # Update the stamp to be able to run a second time. + sess.run([ensemble_stamp.assign_add(1)]) + + train_op.run() + stamp_token, serialized = model_ops.tree_ensemble_serialize( + ensemble_handle) + output = tree_config_pb2.DecisionTreeEnsembleConfig() + output.ParseFromString(serialized.eval()) + self.assertEquals(len(output.trees), 1) + self.assertAllClose(output.tree_weights, [0.1]) + self.assertEquals(stamp_token.eval(), 2) + expected_tree = """ + nodes { + sparse_float_binary_split_default_right { + split{ + left_id: 1 + right_id: 2 + } + } + node_metadata { + gain: 1.125 + } + } + nodes { + leaf { + vector { + value: 1.0 + } + } + } + nodes { + leaf { + vector { + value: -0.5 + } + } + }""" + self.assertProtoEquals(expected_tree, output.trees[0]) + def testTrainFnChiefScalingNumberOfExamples(self): """Tests the train function running on chief without bias centering.""" with self.test_session() as sess: @@ -339,7 +443,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=num_examples_fn, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -442,7 +547,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -513,7 +619,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -576,7 +683,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) @@ -622,7 +730,8 @@ class GbdtTest(test_util.TensorFlowTestCase): with self.test_session() as sess: # Create ensemble with one bias node. ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - text_format.Merge(""" + text_format.Merge( + """ trees { nodes { leaf { @@ -659,16 +768,129 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=1, features=features) + logits_dimension=1, + features=features) # Create predict op. mode = model_fn.ModeKeys.EVAL predictions_dict = sess.run(gbdt_model.predict(mode)) self.assertEquals(predictions_dict["ensemble_stamp"], 3) - self.assertAllClose(predictions_dict["predictions"], [[0.25], [0.25], - [0.25], [0.25]]) + self.assertAllClose(predictions_dict["predictions"], + [[0.25], [0.25], [0.25], [0.25]]) self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + def testPredictFnWithLeafIndexAdvancedLeft(self): + """Tests the predict function with output leaf ids.""" + with self.test_session() as sess: + # Create ensemble with one bias node. + ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge( + """ + trees { + nodes { + dense_float_binary_split { + threshold: 1.0 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 0 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.15 + } + } + } + } + trees { + nodes { + dense_float_binary_split { + threshold: 0.99 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 00 + } + } + nodes { + leaf { + vector { + value: 0.25 + } + } + } + nodes { + leaf { + vector { + value: 0.23 + } + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + }""", ensemble_config) + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=3, + tree_ensemble_config=ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 + learner_config.num_classes = 2 + learner_config.regularization.l1 = 0 + learner_config.regularization.l2 = 0 + learner_config.constraints.max_tree_depth = 1 + learner_config.constraints.min_node_weight = 0 + features = {} + features["dense_float"] = array_ops.constant( + [[0.0], [1.0], [1.1], [2.0]], dtype=dtypes.float32) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=False, + num_ps_replicas=0, + center_bias=True, + ensemble_handle=ensemble_handle, + examples_per_layer=1, + learner_config=learner_config, + logits_dimension=1, + features=features, + output_leaf_index=True) + + # Create predict op. + mode = model_fn.ModeKeys.INFER + predictions_dict = sess.run(gbdt_model.predict(mode)) + self.assertEquals(predictions_dict["ensemble_stamp"], 3) + # here are how the numbers in expected results are calculated, + # 0.5 = 0.25 + 0.25 + # 0.48 = 0.25 + 0.23 + # 0.38 = 0.15 + 0.23 + # 0.38 = 0.15 + 0.23 + self.assertAllClose(predictions_dict["predictions"], + [[0.5], [0.48], [0.38], [0.38]]) + self.assertAllClose(predictions_dict["partition_ids"], [0, 0, 0, 0]) + self.assertAllClose(predictions_dict["leaf_index"], + [[1, 1], [1, 2], [2, 2], [2, 2]]) + def testTrainFnMulticlassFullHessian(self): """Tests the GBDT train for multiclass full hessian.""" with self.test_session() as sess: @@ -698,7 +920,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -801,7 +1024,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) predictions = array_ops.constant( [[0.0, -1.0, 0.5, 1.2, 3.1], [1.0, 0.0, 0.8, 0.3, 1.0], @@ -893,8 +1117,8 @@ class GbdtTest(test_util.TensorFlowTestCase): learner_config.constraints.max_tree_depth = 1 learner_config.constraints.min_node_weight = 0 features = { - "dense_float": array_ops.constant( - [[1.0], [1.5], [2.0]], dtypes.float32), + "dense_float": + array_ops.constant([[1.0], [1.5], [2.0]], dtypes.float32), } gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( @@ -904,7 +1128,8 @@ class GbdtTest(test_util.TensorFlowTestCase): ensemble_handle=ensemble_handle, examples_per_layer=1, learner_config=learner_config, - logits_dimension=5, features=features) + logits_dimension=5, + features=features) batch_size = 3 predictions = array_ops.constant( @@ -986,7 +1211,8 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertAllClose( 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0], - atol=1e-4, rtol=1e-4) + atol=1e-4, + rtol=1e-4) def testTrainFnChiefFeatureSelectionReachedLimitNoGoodSplit(self): """Tests the train function running on chief with feature selection.""" @@ -1230,9 +1456,9 @@ class GbdtTest(test_util.TensorFlowTestCase): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree = tree_ensemble_config.trees.add() - _set_float_split(tree.nodes.add() - .sparse_float_binary_split_default_right.split, 2, 4.0, - 1, 2) + _set_float_split( + tree.nodes.add().sparse_float_binary_split_default_right.split, 2, + 4.0, 1, 2) _append_to_leaf(tree.nodes.add().leaf, 0, 0.5) _append_to_leaf(tree.nodes.add().leaf, 1, 1.2) tree_ensemble_config.tree_weights.append(1.0) @@ -1241,7 +1467,8 @@ class GbdtTest(test_util.TensorFlowTestCase): metadata.num_layers_grown = 1 tree_ensemble_config = tree_ensemble_config.SerializeToString() ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, tree_ensemble_config=tree_ensemble_config, + stamp_token=0, + tree_ensemble_config=tree_ensemble_config, name="tree_ensemble") learner_config = learner_pb2.LearnerConfig() learner_config.learning_rate_tuner.fixed.learning_rate = 0.1 @@ -1333,5 +1560,6 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEquals(output.growing_metadata.num_layers_attempted, 2) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py index af8df72618b7255e182e98e6e4b96a0333b3dce6..8ae493ba998bd882b5ef946f927ec1882d91f61d 100644 --- a/tensorflow/contrib/checkpoint/__init__.py +++ b/tensorflow/contrib/checkpoint/__init__.py @@ -18,11 +18,15 @@ Visualization and inspection: @@dot_graph_from_checkpoint @@object_metadata -Creating and managing dependencies: +Managing dependencies: @@Checkpointable @@CheckpointableObjectGraph @@NoDependency @@split_dependency + +Checkpointable data structures: +@@List +@@Mapping @@UniqueNameTracker """ @@ -36,8 +40,11 @@ from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkp from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph from tensorflow.python.training.checkpointable.base import Checkpointable from tensorflow.python.training.checkpointable.base import NoDependency +from tensorflow.python.training.checkpointable.data_structures import List +from tensorflow.python.training.checkpointable.data_structures import Mapping from tensorflow.python.training.checkpointable.util import object_metadata from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(module_name=__name__) + diff --git a/tensorflow/contrib/checkpoint/python/BUILD b/tensorflow/contrib/checkpoint/python/BUILD index 53f4e97f9932104933b3ecf80142e5af82cd487a..7b200a29bf60087d6da1010b0be05c04faec80cd 100644 --- a/tensorflow/contrib/checkpoint/python/BUILD +++ b/tensorflow/contrib/checkpoint/python/BUILD @@ -11,6 +11,7 @@ py_library( ":containers", ":split_dependency", ":visualize", + "//tensorflow/python/training/checkpointable:data_structures", ], ) @@ -19,7 +20,10 @@ py_library( srcs = ["containers.py"], srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], - deps = ["//tensorflow/python/training/checkpointable:base"], + deps = [ + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:data_structures", + ], ) py_test( @@ -30,8 +34,8 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", "@six_archive//:six", ], ) @@ -44,6 +48,7 @@ py_library( deps = [ "//tensorflow/python:control_flow_ops", "//tensorflow/python:training", + "//tensorflow/python/training/checkpointable:base", ], ) @@ -55,8 +60,9 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:training", "//tensorflow/python/eager:test", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -67,6 +73,8 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python/training/checkpointable:base", + "//tensorflow/python/training/checkpointable:util", ], ) @@ -75,10 +83,13 @@ py_test( srcs = ["visualize_test.py"], deps = [ ":visualize", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:constant_op", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", + "//tensorflow/python/keras:engine", + "//tensorflow/python/keras:layers", + "//tensorflow/python/training/checkpointable:util", ], ) diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py index 9807abae1f5106bb84f858c3725f096aaa4eaca9..4d3d5312993740636709cb732c0b8e3e2626262d 100644 --- a/tensorflow/contrib/checkpoint/python/containers.py +++ b/tensorflow/contrib/checkpoint/python/containers.py @@ -18,9 +18,10 @@ from __future__ import division from __future__ import print_function from tensorflow.python.training.checkpointable import base as checkpointable_lib +from tensorflow.python.training.checkpointable import data_structures -class UniqueNameTracker(checkpointable_lib.CheckpointableBase): +class UniqueNameTracker(data_structures.CheckpointableDataStructure): """Adds dependencies on checkpointable objects with name hints. Useful for creating dependencies with locally unique names. @@ -41,6 +42,7 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase): """ def __init__(self): + super(UniqueNameTracker, self).__init__() self._maybe_initialize_checkpointable() self._name_counts = {} @@ -74,4 +76,5 @@ class UniqueNameTracker(checkpointable_lib.CheckpointableBase): count += 1 candidate = _format_name(base_name, count) self._name_counts[base_name] = count + 1 - return self._track_checkpointable(checkpointable, name=candidate) + self._track_value(checkpointable, name=candidate) + return checkpointable diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py index 851a80058852bd917aec075b4bf63264318603a7..3717d7f583ffdc205a279d45df60cddbc5cbf08e 100644 --- a/tensorflow/contrib/checkpoint/python/containers_test.py +++ b/tensorflow/contrib/checkpoint/python/containers_test.py @@ -22,6 +22,8 @@ import six from tensorflow.contrib.checkpoint.python import containers from tensorflow.python.framework import test_util +from tensorflow.python.keras import layers +from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test from tensorflow.python.training.checkpointable import base as checkpointable @@ -95,5 +97,12 @@ class UniqueNameTrackerTests(test.TestCase): dependency_names, ["x", "x_1", "y", "slot_manager", "slotdeps", "save_counter"]) + @test_util.run_in_graph_and_eager_modes() + def testLayers(self): + tracker = containers.UniqueNameTracker() + tracker.track(layers.Dense(3), "dense") + tracker.layers[0](array_ops.zeros([1, 1])) + self.assertEqual(2, len(tracker.trainable_weights)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index f3a75e8688ece19a6e6fd53ee9faf7f4144d76cf..42ba368531468b789a87429f88ca84937f9b909d 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -15,7 +15,10 @@ load( ) tf_gen_op_libs( - op_lib_names = ["bigquery_reader_ops"], + op_lib_names = [ + "bigquery_reader_ops", + "gcs_config_ops", + ], deps = [ "//tensorflow/core:lib", ], @@ -28,15 +31,25 @@ tf_gen_op_wrapper_py( deps = [":bigquery_reader_ops_op_lib"], ) +tf_gen_op_wrapper_py( + name = "gen_gcs_config_ops", + out = "python/ops/gen_gcs_config_ops.py", + require_shape_functions = True, + visibility = ["//tensorflow:internal"], + deps = [":gcs_config_ops_op_lib"], +) + py_library( name = "cloud_py", srcs = [ "__init__.py", "python/ops/bigquery_reader_ops.py", + "python/ops/gcs_config_ops.py", ], srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", + ":gen_gcs_config_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index 8870264b95dfd9f8c4b1655c475fe23e0639924f..a6e13ea3ae938444b9ead0772e52fb8797a847da 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -20,9 +20,15 @@ from __future__ import print_function # pylint: disable=line-too-long,wildcard-import from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * # pylint: enable=line-too-long,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['BigQueryReader'] +_allowed_symbols = [ + 'BigQueryReader', + 'ConfigureColabSession', + 'ConfigureGcs', + 'ConfigureGcsHook', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index ff46f0daa80a70badedf73e15bfaf4dca85fdd89..40160706f70e8fa8323005dd183770ed51c8c415 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -73,3 +73,17 @@ tf_proto_library( srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) + +tf_kernel_library( + name = "gcs_config_ops", + srcs = ["gcs_config_ops.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform/cloud:curl_http_request", + "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/platform/cloud:oauth_client", + "@jsoncpp_git//:jsoncpp", + ], +) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..648a219fb87a6ebc64767a7da780013ef6b95443 --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. +constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request. +constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; + +Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { + DCHECK(fs != nullptr); + *fs = nullptr; + + FileSystem* filesystem = nullptr; + TF_RETURN_IF_ERROR( + ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); + if (filesystem == nullptr) { + return errors::FailedPrecondition("The GCS file system is not registered."); + } + + *fs = dynamic_cast(filesystem); + if (*fs == nullptr) { + return errors::Internal( + "The filesystem registered under the 'gs://' scheme was not a " + "tensorflow::RetryingGcsFileSystem*."); + } + return Status::OK(); +} + +template +Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar()(); + return Status::OK(); +} + +// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. +class GcsCredentialsOpKernel : public OpKernel { + public: + explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + string json_string; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "json", &json_string)); + + Json::Value json; + Json::Reader reader; + std::stringstream json_stream(json_string); + OP_REQUIRES(ctx, reader.parse(json_stream, json), + errors::InvalidArgument("Could not parse json: ", json_string)); + + OP_REQUIRES( + ctx, json.isMember("refresh_token") || json.isMember("private_key"), + errors::InvalidArgument("JSON format incompatible; did not find fields " + "`refresh_token` or `private_key`.")); + + auto provider = + tensorflow::MakeUnique(json, ctx->env()); + + // Test getting a token + string dummy_token; + OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); + OP_REQUIRES(ctx, !dummy_token.empty(), + errors::InvalidArgument( + "Could not retrieve a token with the given credentials.")); + + // Set the provider. + gcs->underlying()->SetAuthProvider(std::move(provider)); + } + + private: + class ConstantAuthProvider : public AuthProvider { + public: + ConstantAuthProvider(const Json::Value& json, + std::unique_ptr oauth_client, Env* env, + int64 initial_retry_delay_usec) + : json_(json), + oauth_client_(std::move(oauth_client)), + env_(env), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + + ConstantAuthProvider(const Json::Value& json, Env* env) + : ConstantAuthProvider(json, tensorflow::MakeUnique(), env, + kInitialRetryDelayUsec) {} + + ~ConstantAuthProvider() override {} + + Status GetToken(string* token) override { + mutex_lock l(mu_); + const uint64 now_sec = env_->NowSeconds(); + + if (!current_token_.empty() && + now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { + *token = current_token_; + return Status::OK(); + } + if (json_.isMember("refresh_token")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + json_, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_)); + } else if (json_.isMember("private_key")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + json_, kOAuthV4Url, kOAuthScope, ¤t_token_, + &expiration_timestamp_sec_)); + } else { + return errors::FailedPrecondition( + "Unexpected content of the JSON credentials file."); + } + + *token = current_token_; + return Status::OK(); + } + + private: + Json::Value json_; + std::unique_ptr oauth_client_; + Env* env_; + + mutex mu_; + string current_token_ GUARDED_BY(mu_); + uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0; + + // The initial delay for exponential backoffs when retrying failed calls. + const int64 initial_retry_delay_usec_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); + }; +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU), + GcsCredentialsOpKernel); + +class GcsBlockCacheOpKernel : public OpKernel { + public: + explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + size_t max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "max_cache_size", + &max_cache_size)); + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_size", &block_size)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument(ctx, "max_staleness", &max_staleness)); + + if (gcs->underlying()->block_size() == block_size && + gcs->underlying()->max_bytes() == max_cache_size && + gcs->underlying()->max_staleness() == max_staleness) { + LOG(INFO) << "Skipping resetting the GCS block cache."; + return; + } + gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, + max_staleness); + } +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU), + GcsBlockCacheOpKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..9cf85f5f1811d873075b6d2e1931d8badfd6e32c --- /dev/null +++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. + +The json input can be of the format: + +1. Refresh Token: +{ + "client_id": "", + "client_secret": "", + "refresh_token: "", + "type": "authorized_user", +} + +2. Service Account: +{ + "type": "service_account", + "project_id": "", + "private_key_id": "", + "private_key": "------BEGIN PRIVATE KEY-----\n\n-----END PRIVATE KEY------\n", + "client_email": "@.iam.gserviceaccount.com", + "client_id": "", + # Some additional fields elided +} + +Note the credentials established through this method are shared across all +sessions run on this runtime. + +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. + +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8c8c5acb31af69b4f738a13c6548cdd31947d71a --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -0,0 +1,188 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training + + +# @tf_export('contrib.cloud.BlockCacheParams') +class BlockCacheParams(object): + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +# @tf_export('contrib.cloud.ConfigureGcsHook') +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if 'refresh_token' in creds_dict or 'private_key' in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. + """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError('credentials was not a well formed JSON string.', e) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + 'field.') + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError('credentials has neither a "refresh_token" nor a ' + '"private_key" field.') + credentials = json.dumps(credentials) + else: + raise ValueError('credentials is of an unknown type') + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError('block_cache must be an instance of BlockCacheParams.') + self._block_cache = block_cache + + def begin(self): + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_placeholder) + if self._block_cache: + self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness) + + def after_create_session(self, session, coord): + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def configure_gcs(session, credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given a session. + + Warning: GCS `credentials` may be transmitted over the network unencrypted. + Please ensure that the network is trusted before using this function. For + users running code entirely within Google Cloud, your data is protected by + encryption in between data centers. For more information, please take a look + at https://cloud.google.com/security/encryption-in-transit/. + + Args: + session: A `tf.Session` session that should be used to configure the GCS + file system. + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + placeholder = array_ops.placeholder(dtypes.string) + op = gen_gcs_config_ops.gcs_configure_credentials(placeholder) + session.run(op, feed_dict={placeholder: credentials}) + if block_cache: + op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness) + session.run(op) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def configure_colab_session(session): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + session: A `tf.Session` session. + """ + # Read from the application default credentials (adc). + with open('/content/datalab/adc.json') as f: + data = json.load(f) + configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 880fca4ea65608472838baee234e468bef37afb3..8f521ffee4d31e090c13bac98290656d6e1d330e 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,6 +36,7 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_ENDPOINTS_SEPARATOR = ',' _DEFAULT_ENV_VARIABLE = 'TPU_NAME' _DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' @@ -69,8 +70,8 @@ class TPUClusterResolver(ClusterResolver): return _GKE_ENV_VARIABLE in os.environ @staticmethod - def _gkeMaster(): - return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + def _gkeEndpoints(): + return os.environ[_GKE_ENV_VARIABLE] @staticmethod def _envVarFallback(): @@ -143,7 +144,7 @@ class TPUClusterResolver(ClusterResolver): # When using GKE with Cloud TPUs, the env variable will be set. if tpu is None: if in_gke: - tpu = self._gkeMaster() + tpu = self._gkeEndpoints() else: tpu = self._envVarFallback() @@ -170,10 +171,11 @@ class TPUClusterResolver(ClusterResolver): if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver. Execute: `pip install ' - '--upgrade google-api-python-client` to install with ' - 'pip.') + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') final_discovery_url = self._discoveryUrl() or discovery_url if final_discovery_url: @@ -213,7 +215,7 @@ class TPUClusterResolver(ClusterResolver): ValueError: If none of the TPUs specified exists. """ if not self._shouldResolve(): - return self._tpu + return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: @@ -255,6 +257,10 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (self._tpu, response['state'])) + if 'health' in response and response['health'] != 'HEALTHY': raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, response['health'])) @@ -275,8 +281,12 @@ class TPUClusterResolver(ClusterResolver): # Case 3. return None # Case 2. - cluster_spec = {self._job_name: [self._tpu[len( - compat.as_bytes('grpc://')):]]} + cluster_spec = { + self._job_name: [ + x[len(compat.as_bytes('grpc://')):] + for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) + ] + } if self._coordinator_address: # {1, 2}.a diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5fac55fd027fa2d100621e08a09e05cdb3a1b941..ad4f6432630be44a7de6e778f55f1fb7fd66f307 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testUnhealthyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testNotReadyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'CREATING' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { @@ -358,13 +402,61 @@ class TPUClusterResolverTest(test.TestCase): compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master()) self.assertEqual(None, tpu_cluster_resolver.cluster_spec()) - def testGkeEnvironment(self): + def testGkeEnvironmentForDonut(self): os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470' - self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ) + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) + self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470'), + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + + def testGkeEnvironmentForPod(self): + os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470') + + self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ) self.assertTrue(TPUClusterResolver._inGke()) + self.assertEqual( + compat.as_bytes('grpc://10.120.27.5:8470,' + 'grpc://10.120.27.6:8470,' + 'grpc://10.120.27.7:8470,' + 'grpc://10.120.27.8:8470'), + compat.as_bytes(TPUClusterResolver._gkeEndpoints())) + + tpu_cluster_resolver = TPUClusterResolver() self.assertEqual( compat.as_bytes('grpc://10.120.27.5:8470'), - compat.as_bytes(TPUClusterResolver._gkeMaster())) + compat.as_bytes(tpu_cluster_resolver.master())) + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.120.27.5:8470' } + tasks { key: 1 value: '10.120.27.6:8470' } + tasks { key: 2 value: '10.120.27.7:8470' } + tasks { key: 3 value: '10.120.27.8:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] def testDiscoveryUrl(self): diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 0708d6b7b9f0ba549aea091a265f42890e50d223..e524e9e7437b19e0d117fe7b85042e8154773a02 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) + +if(WIN32) +# BoringSSL is disabled for windows as it currently doesn't build with +# MSBuild. (Ninja is required.) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) +else() +# BoringSSL is enabled for gRPC. +option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON) +endif() + option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF) diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake index 527ccdc8d887cb4c2e7d2412c99a8bc682568472..5c5adaf5798289fba1c5d0b3f9e0489dc242043e 100644 --- a/tensorflow/contrib/cmake/external/double_conversion.cmake +++ b/tensorflow/contrib/cmake/external/double_conversion.cmake @@ -16,15 +16,15 @@ include (ExternalProject) set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) set(double_conversion_URL https://github.com/google/double-conversion.git) -set(double_conversion_TAG 5664746) +set(double_conversion_TAG 3992066a95b823efc8ccc1baf82a1cfc73f6e9b8) set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) set(double_conversion_INCLUDES ${double_conversion_BUILD}) if(WIN32) - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/$(Configuration)/double-conversion.lib) else() - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) + set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/libdouble-conversion.a) endif() set(double_conversion_HEADERS diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 693dc7cd673233b889b35a3f3170b57581da9a9f..b1e64aa55c80ad59cfdc0f4767c0282b4f73367f 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -20,6 +20,10 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f) if(WIN32) + # We use unsecure gRPC because boringssl does not build on windows + set(grpc_TARGET grpc++_unsecure) + set(grpc_DEPENDS protobuf zlib) + set(grpc_SSL_PROVIDER NONE) if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib @@ -32,9 +36,12 @@ if(WIN32) ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) endif() else() + set(grpc_TARGET grpc++) + set(grpc_DEPENDS boringssl protobuf zlib) + set(grpc_SSL_PROVIDER module) set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) @@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0) ExternalProject_Add(grpc PREFIX grpc - DEPENDS protobuf zlib + DEPENDS ${grpc_DEPENDS} GIT_REPOSITORY ${GRPC_URL} GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET} COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" CMAKE_CACHE_ARGS @@ -59,7 +66,7 @@ ExternalProject_Add(grpc -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS} -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES} -DZLIB_ROOT:STRING=${ZLIB_INSTALL} - -DgRPC_SSL_PROVIDER:STRING=NONE + -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER} ) # grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h. diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index fece56c4127de4deebc1404f0eff9747f99ba89f..015cb73bbd93bb77f6748a364b263d99eb305c27 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -115,6 +115,8 @@ tensorflow/contrib/coder/python/ops tensorflow/contrib/compiler tensorflow/contrib/constrained_optimization tensorflow/contrib/constrained_optimization/python +tensorflow/contrib/control_flow +tensorflow/contrib/control_flow/python tensorflow/contrib/copy_graph tensorflow/contrib/copy_graph/python tensorflow/contrib/copy_graph/python/util diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index a06bdf78fb011b288d5d7af6488ec6802ff34c35..2e0a2fcef4cbdc50f0521296c4a25a864dbd8b77 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,6 +21,7 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/c_api_debug.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.h" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index b47c32f1c48b3d42fe5b4ba115cc2a511b7ee5f4..dac84ccb0dbf4848329e35a6e9bcf6213d8c0e55 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -213,10 +213,6 @@ else() list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() -file(GLOB tf_core_platform_exclude_srcs - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc") -list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs}) - list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs}) if(UNIX) @@ -286,8 +282,6 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/framework/*.h" "${tensorflow_source_dir}/tensorflow/core/framework/*.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h" - "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index e558691de4b74988031f7b2204aad92e8c7af68b..bc753333dba4f67eee0114c4022743dd59a05982 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -113,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(gcs_config "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 894b1ead7688bd951c5c78f613fdb7aae226fe65..92446044892127284ecb8753a250b77cb2a5743a 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -420,6 +420,8 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" @@ -723,7 +725,7 @@ endif() ######################################################## # Parse tensorflow/tools/api/generator/BUILD to get list of generated files. -FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/BUILD api_generator_BUILD_text) +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) STRING(REGEX MATCH "# BEGIN GENERATED FILES.*# END GENERATED FILES" api_init_files_text ${api_generator_BUILD_text}) string(REPLACE "# BEGIN GENERATED FILES" "" api_init_files_text ${api_init_files_text}) string(REPLACE "# END GENERATED FILES" "" api_init_files_text ${api_init_files_text}) @@ -734,7 +736,7 @@ foreach(api_init_file ${api_init_files_list}) string(STRIP "${api_init_file}" api_init_file) if(api_init_file) string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes - list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/${api_init_file}") + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/${api_init_file}") endif() endforeach(api_init_file) set(api_init_list_file "${tensorflow_source_dir}/api_init_files_list.txt") @@ -747,18 +749,16 @@ add_custom_command( # tensorflow/__init__.py depends on files generated in this step. So, remove it while # this step is running since the files aren't there yet. - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py # Run create_python_api.py to generate API init files. COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" "${api_init_list_file}" - - # Re-add tensorflow/__init__.py back. - COMMAND ${CMAKE_COMMAND} -E remove -f ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py - COMMAND ${CMAKE_COMMAND} -E rename ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/final.__init__.py - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/__init__.py + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--root_init_template=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/api_template.__init__.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow" + "--package=tensorflow.python" + "--apiname=tensorflow" + "${api_init_list_file}" COMMENT "Generating __init__.py files for Python API." WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" @@ -767,7 +767,49 @@ add_custom_command( add_custom_target(tf_python_api SOURCES ${api_init_files}) add_dependencies(tf_python_api tf_python_ops) +# TODO(mikecase): This can be removed once tf.estimator is moved +# out of TensorFlow. +######################################################## +# Generate API __init__.py files for tf.estimator. +######################################################## + +# Parse tensorflow/tools/api/generator/BUILD to get list of generated files. +FILE(READ ${tensorflow_source_dir}/tensorflow/tools/api/generator/api_gen.bzl api_generator_BUILD_text) +STRING(REGEX MATCH "# BEGIN GENERATED ESTIMATOR FILES.*# END GENERATED ESTIMATOR FILES" api_init_files_text ${api_generator_BUILD_text}) +string(REPLACE "# BEGIN GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "# END GENERATED ESTIMATOR FILES" "" api_init_files_text ${api_init_files_text}) +string(REPLACE "," ";" api_init_files_list ${api_init_files_text}) + +set(api_init_files "") +foreach(api_init_file ${api_init_files_list}) + string(STRIP "${api_init_file}" api_init_file) + if(api_init_file) + string(REPLACE "\"" "" api_init_file "${api_init_file}") # Remove quotes + list(APPEND api_init_files "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api/${api_init_file}") + endif() +endforeach(api_init_file) +set(estimator_api_init_list_file "${tensorflow_source_dir}/estimator_api_init_files_list.txt") +file(WRITE "${estimator_api_init_list_file}" "${api_init_files}") + +# Run create_python_api.py to generate __init__.py files. +add_custom_command( + OUTPUT ${api_init_files} + DEPENDS tf_python_ops tf_python_copy_scripts_to_destination pywrap_tensorflow_internal tf_python_touchup_modules tf_extension_ops + + # Run create_python_api.py to generate API init files. + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}/tf_python ${PYTHON_EXECUTABLE} + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tools/api/generator/create_python_api.py" + "--apidir=${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/estimator/api" + "--package=tensorflow.python.estimator" + "--apiname=estimator" + "${estimator_api_init_list_file}" + + COMMENT "Generating __init__.py files for Python API." + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/tf_python" +) +add_custom_target(estimator_python_api SOURCES ${api_init_files}) +add_dependencies(estimator_python_api tf_python_ops) ############################################################ # Build a PIP package containing the TensorFlow runtime. ############################################################ @@ -778,6 +820,7 @@ add_dependencies(tf_python_build_pip_package tf_python_touchup_modules tf_python_ops tf_python_api + estimator_python_api tf_extension_ops) # Fix-up Python files that were not included by the add_python_module() macros. diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 5942ff3363a96de70df7e13d0857e4ad82e35fee..eb9482dc25f2be8ce46cc38bf3dd28889b09a9d4 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -212,6 +212,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py" # Disable following manual tag in BUILD. "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py" + # These tests depend on a .so file + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/duplicate_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/invalid_op_test.py + ${tensorflow_source_dir}/tensorflow/python/kernel_tests/ackermann_test.py ) if (WIN32) diff --git a/tensorflow/contrib/control_flow/BUILD b/tensorflow/contrib/control_flow/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e8036d63aeeac224b226899c036416a06b4ffe65 --- /dev/null +++ b/tensorflow/contrib/control_flow/BUILD @@ -0,0 +1,53 @@ +# New implementations of control flow ops + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +py_library( + name = "control_flow", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":cond_v2", + ], +) + +py_library( + name = "cond_v2", + srcs = ["python/cond_v2.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:c_api_util", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:function_def_to_graph", + "//tensorflow/python:functional_ops_gen", + "//tensorflow/python:gradients", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:util", + ], +) + +tf_py_test( + name = "cond_v2_test", + size = "small", + srcs = ["python/cond_v2_test.py"], + additional_deps = [ + ":cond_v2", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:training", + ], + grpc_enabled = True, +) diff --git a/tensorflow/contrib/control_flow/__init__.py b/tensorflow/contrib/control_flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..582af2cf10a3d92dd8611b0f2826625e3acfb099 --- /dev/null +++ b/tensorflow/contrib/control_flow/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""New implementations of TF control flow ops. + +@@cond_v2 +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.control_flow.python.cond_v2 import cond_v2 +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/control_flow/python/cond_v2.py b/tensorflow/contrib/control_flow/python/cond_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffad9caa92d2d3be8f598758a443b0eceb8d4d8 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2.py @@ -0,0 +1,429 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""cond_v2 and gradient. + +This is a version of cond that emits a single If op, as well as the gradient +function for If ops produced by cond_v2. This will eventually replace the +current tf.cond implementation once it reaches feature and performance parity. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python import pywrap_tensorflow as c_api +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import function +from tensorflow.python.framework import function_def_to_graph +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_functional_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.util import compat + + +# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify +# that they aren't part of the official public API. These protected members +# often need to be used by implementation code however. Rather than litter the +# code with pylint comments, we ignore protected access violations for +# readability. +# pylint: disable=protected-access + + +def cond_v2(pred, true_fn, false_fn, name="cond"): + """Like tf.cond, except emits a single If op.""" + with ops.name_scope(name) as scope: + true_graph = function.func_graph_from_py_func(true_fn, [], [], + name="%s_true" % scope) + false_graph = function.func_graph_from_py_func(false_fn, [], [], + name="%s_false" % scope) + _check_same_outputs(true_graph, false_graph) + + # Add inputs to true_graph and false_graph to make them match. Note that + # this modifies true_graph and false_graph. + cond_inputs = _make_inputs_match(true_graph, false_graph, + true_graph.extra_inputs, + false_graph.extra_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # the gradient computation. + + true_intermediates = _get_intermediates(true_graph) + false_intermediates = _get_intermediates(false_graph) + + # Save the original number of outputs to return to the caller. + num_cond_outputs = len(true_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_outputs, extra_false_outputs = _pad_params( + true_graph, false_graph, true_intermediates, false_intermediates) + + true_graph.outputs.extend(extra_true_outputs) + false_graph.outputs.extend(extra_false_outputs) + + # Create the If op. + tensors = gen_functional_ops._if( + pred, cond_inputs, [t.dtype for t in true_graph.outputs], + _create_new_tf_function(true_graph), + _create_new_tf_function(false_graph), + name=scope) + + return tensors[:num_cond_outputs] + + +@ops.RegisterGradient("If") +def _IfGrad(op, *grads): # pylint: disable=invalid-name + """The gradient of an If op produced by cond_v2.""" + true_graph, false_graph = _get_func_graphs(op) + + # Create grad functions that compute the gradient of the true/false forward + # graphs. These functions will capture tensors from the forward pass + # functions. + true_grad_graph = _create_grad_func( + true_graph, grads, _get_grad_fn_name(true_graph)) + false_grad_graph = _create_grad_func( + false_graph, grads, _get_grad_fn_name(false_graph)) + + assert ([t.dtype for t in true_grad_graph.outputs] == + [t.dtype for t in false_grad_graph.outputs]) + + # Match up the captured grad function inputs with outputs of 'op' and other + # external tensors. + true_grad_inputs = _get_grad_inputs(op, true_graph, true_grad_graph) + false_grad_inputs = _get_grad_inputs(op, false_graph, false_grad_graph) + + # Make the inputs to true_grad_graph and false_grad_graph match. Note that + # this modifies true_grad_graph and false_grad_graph. + grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph, + true_grad_inputs, false_grad_inputs) + + # Add all intermediate tensors as function outputs so they're available for + # higher-order gradient computations. + + true_grad_intermediates = _get_intermediates(true_grad_graph) + false_grad_intermediates = _get_intermediates(false_grad_graph) + + # Save the original number of gradient outputs to return. + num_grad_outputs = len(true_grad_graph.outputs) + + # Make the number/type of new intermediate outputs match. + extra_true_grad_outputs, extra_false_grad_outputs = _pad_params( + true_grad_graph, false_grad_graph, + true_grad_intermediates, false_grad_intermediates) + + true_grad_graph.outputs.extend(extra_true_grad_outputs) + false_grad_graph.outputs.extend(extra_false_grad_outputs) + + # Create the gradient If op. + tensors = gen_functional_ops._if( + op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs], + _create_new_tf_function(true_grad_graph), + _create_new_tf_function(false_grad_graph)) + + # The predicate has no gradient. + return [None] + tensors[:num_grad_outputs] + + +def _get_func_graphs(if_op): + """Returns `_FuncGraph`s for the input op branches. + + Args: + if_op: The _If Operation. + + Returns: + A 2-tuple of the `_FuncGraph`s of the then_branch and else_branch. + """ + def _get_func_graph_for_branch(branch_name): + extra_inputs = if_op.inputs[1:] # First input is pred. + input_shapes = [t.shape for t in extra_inputs] + func_name = if_op.get_attr(branch_name).name + fdef = if_op.graph._get_function(func_name).definition + func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes) + func_graph.extra_inputs = extra_inputs + func_graph.extra_args = func_graph.inputs + func_graph._captured = dict(zip(extra_inputs, func_graph.inputs)) + return func_graph + + return (_get_func_graph_for_branch("then_branch"), + _get_func_graph_for_branch("else_branch")) + + +def _grad_fn(func_graph, grads): + """The gradient function for each conditional branch. + + This function builds the gradient graph of the corresponding forward-pass + conditional branch in `func_graph`. This is done by differentiating + func_graph's outputs w.r.t. its inputs. + + Args: + func_graph: function._FuncGraph. The corresponding forward-pass function. + grads: The list of input gradient Tensors. + + Returns: + The output gradient Tensors. + """ + # Filter out untrainable function outputs. + # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes + # cause _GradientsHelper to raise an exception (e.g. the implementation + # doesn't expect 'ys' to contain boolean tensors). + assert len(func_graph.outputs) == len(grads) + ys = [] + grad_ys = [] + for y, grad_y in zip(func_graph.outputs, grads): + if not gradients_impl._IsTrainable(y): + continue + ys.append(y) + grad_ys.append(grad_y) + + # Build the gradient graph. Note that this builds the gradient computation of + # func_graph in the current graph, which requires capturing tensors from + # func_graph. The captured func_graph tensors are resolved to external tensors + # in _get_grad_inputs. + result = gradients_impl._GradientsHelper( + ys, func_graph.inputs, grad_ys=grad_ys, + src_graph=func_graph) + + # Functions can't return None; replace Nones with zero tensors. + # TODO(b/80444525): don't return anything here and make _IfGrad return None if + # both branches have zero gradient. + for i in range(len(result)): + if result[i] is None: + result[i] = array_ops.zeros_like(func_graph.inputs[i]) + + return result + + +def _create_grad_func(func_graph, grads, name): + """Returns the _FuncGraph representation of _grad_fn.""" + return function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads), + [], [], name) + + +def _get_grad_inputs(if_op, cond_graph, grad_graph): + """Returns the tensors we should pass to grad_graph. + + This method handles tensors captured from cond_graph in grad_graph. It + converts these to suitable input tensors from the outer graph. + + Args: + if_op: Operation. The forward-pass If op that uses cond_graph. + cond_graph: function._FuncGraph. The forward-pass function. + grad_graph: function._FuncGraph. The gradients function. + + Returns: + A list of inputs tensors to be passed to grad_graph. + """ + inputs = [] + + # Maps placeholders in cond_graph -> input tensor in outer graph. + forward_input_map = {v: k for k, v in cond_graph._captured.items()} + + for t in grad_graph.extra_inputs: + if t.graph == ops.get_default_graph(): + # t is in the outer graph (e.g. one of the input gradients). + inputs.append(t) + elif t in forward_input_map: + # t is an input placeholder in cond_graph. Get the corresponding input + # tensor in the outer graph. + assert t.graph == cond_graph + assert forward_input_map[t].graph == ops.get_default_graph() + inputs.append(forward_input_map[t]) + else: + # t is an intermediate value in cond_graph. Get the corresponding output + # of 'if_op' (note that all intermediate values are outputs). + assert t.graph == cond_graph + output_idx = cond_graph.outputs.index(t) + inputs.append(if_op.outputs[output_idx]) + + return inputs + + +def _create_new_tf_function(func_graph): + """Converts func_graph to a TF_Function and adds it to the current graph. + + Args: + func_graph: function._FuncGraph + + Returns: + The name of the new TF_Function. + """ + c_func = c_api.TF_GraphToFunction_wrapper( + func_graph._c_graph, + compat.as_str(func_graph.name), + False, # append_hash_to_fn_name + None, # opers + [t._as_tf_output() for t in func_graph.inputs], + [t._as_tf_output() for t in func_graph.outputs], + [], + None, # opts + None) # description + _ = c_api_util.ScopedTFFunction(c_func) + + # TODO(b/109833212): this sucks, we're serializing the TF_Function*, + # deserializing it into a Python FunctionDef, then reserializing it to create + # a new TF_Function that we add to the graph. + fdef = function.function_def_from_tf_function(c_func) + defined_func = function._from_definition(fdef) + defined_func.add_to_graph(ops.get_default_graph()) + + return func_graph.name + + +def _get_intermediates(func_graph): + """Returns all tensors in `func_graph` that aren't inputs or outputs.""" + intermediates = [] + for op in func_graph.get_operations(): + for t in op.outputs: + if t in func_graph.inputs: continue + if t in func_graph.outputs: continue + intermediates.append(t) + return intermediates + + +def _separate_unique_inputs(true_inputs, false_inputs): + """Separates tensors appearing only in true_inputs or false_inputs, or both. + + Args: + true_inputs: list of Tensors + false_inputs: list of Tensors + + Returns: + Three lists of Tensors: + 1. The tensors that appear in both true_inputs and false_inputs + 2. The tensors that only appear in true_inputs + 3. The tensors that only appear in false_inputs + """ + true_inputs = set(true_inputs) + false_inputs = set(false_inputs) + + shared_inputs = true_inputs.intersection(false_inputs) + true_only_inputs = true_inputs - false_inputs + false_only_inputs = false_inputs - true_inputs + + return list(shared_inputs), list(true_only_inputs), list(false_only_inputs) + + +def _pad_params(true_graph, false_graph, true_params, false_params): + """Returns new param lists that have matching signatures. + + This is done by mirroring each param list in the other using dummy params. + There is no merging of params. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_params: a list of Tensors from true_graph + false_params: a list of Tensors from false_graph + + Returns: + A new list of Tensors in true_graph and a new list of Tensors in + false_graph. The two lists have the same number of Tensors, with matching + types and shapes across the lists. + """ + new_true_params = (true_params + + _create_dummy_params(true_graph, false_params)) + new_false_inputs = (_create_dummy_params(false_graph, true_params) + + false_params) + return new_true_params, new_false_inputs + + +def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs): + """Modifies true_graph and false_graph so they have the same input signature. + + This method reorders and/or adds parameters to true_graph and false_graph so + they have the same input signature, and updates the 'inputs', 'extra_inputs', + and '_captured' fields of both graphs accordingly. It uses the input tensors + from the outer graph to avoid duplicating shared arguments. + + Args: + true_graph: function._FuncGraph + false_graph: function._FuncGraph + true_inputs: a list of Tensors in the outer graph. The inputs for + true_graph. + false_inputs: a list of Tensors in the outer graph. The inputs for + false_graph. + + Returns: + A new list of Tensors from the outer graph that are the new inputs for both + true_graph and false_graph. This is a deduped version of true_inputs + + false_inputs. + """ + shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs( + true_inputs, false_inputs) + + new_inputs = shared_inputs + true_only_inputs + false_only_inputs + + true_input_to_param = dict(zip(true_inputs, true_graph.inputs)) + false_input_to_param = dict(zip(false_inputs, false_graph.inputs)) + + true_graph.inputs = ( + [true_input_to_param[t] for t in shared_inputs] + + [true_input_to_param[t] for t in true_only_inputs] + + _create_dummy_params(true_graph, false_only_inputs)) + + false_graph.inputs = ( + [false_input_to_param[t] for t in shared_inputs] + + _create_dummy_params(false_graph, true_only_inputs) + + [false_input_to_param[t] for t in false_only_inputs]) + + # Rewrite the _FuncGraphs' state to reflect the new inputs. + true_graph.extra_inputs = new_inputs + false_graph.extra_inputs = new_inputs + + true_graph._captured = dict(zip(new_inputs, true_graph.inputs)) + false_graph._captured = dict(zip(new_inputs, false_graph.inputs)) + + return new_inputs + + +def _create_dummy_params(func_graph, template_tensors): + """Creates tensors in func_graph to represent template_tensors. + + Args: + func_graph: function._FuncGraph. + template_tensors: a list of tensors in the outer graph. + + Returns: + A list of tensors in func_graph. + """ + with func_graph.as_default(): + return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) + for t in template_tensors] + + +def _get_grad_fn_name(func_graph): + """Returns a unique name to use for the grad function of `func_graph`.""" + name = "%s_grad" % func_graph.name + + base_name = name + counter = 1 + if ops.get_default_graph()._is_function(name): + name = "%s_%s" % (base_name, counter) + counter += 1 + + return name + + +def _check_same_outputs(true_graph, false_graph): + """Raises an error if true_graph and false_graph have different outputs.""" + true_output_types = [t.dtype for t in true_graph.outputs] + false_output_types = [t.dtype for t in false_graph.outputs] + if (len(true_graph.outputs) != len(false_graph.outputs) or + true_output_types != false_output_types): + raise ValueError( + "true_fn() and false_fn() must return the same number and type of " + "arguments, got:\n" + " true_fn: %s\n" + " false_fn: %s" % (true_output_types, false_output_types)) diff --git a/tensorflow/contrib/control_flow/python/cond_v2_test.py b/tensorflow/contrib/control_flow/python/cond_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..338601aa2c5ee8ffc97aa2e07ff05a2d17531936 --- /dev/null +++ b/tensorflow/contrib/control_flow/python/cond_v2_test.py @@ -0,0 +1,171 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for cond_v2.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.control_flow.python import cond_v2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver + + +class NewCondTest(test.TestCase): + + def _testCond(self, true_fn, false_fn, train_vals): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + expected = control_flow_ops.cond(pred, true_fn, false_fn, name="expected") + actual = cond_v2.cond_v2(pred, true_fn, false_fn, name="actual") + + expected_grad = gradients_impl.gradients(expected, train_vals) + actual_grad = gradients_impl.gradients(actual, train_vals) + + with self.test_session() as sess: + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: True}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + expected_val, actual_val, expected_grad_val, actual_grad_val = sess.run( + (expected, actual, expected_grad, actual_grad), {pred: False}) + self.assertEqual(expected_val, actual_val) + self.assertEqual(expected_grad_val, actual_grad_val) + + def testBasic(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * 2.0 + + def false_fn(): + return y * 3.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testBasic2(self): + x = constant_op.constant(1.0, name="x") + y = constant_op.constant(2.0, name="y") + + def true_fn(): + return x * y * 2.0 + + def false_fn(): + return 2.0 + + self._testCond(true_fn, false_fn, [x]) + self._testCond(true_fn, false_fn, [x, y]) + self._testCond(true_fn, false_fn, [y]) + + def testNoInputs(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + + def true_fn(): + return constant_op.constant(1.0) + + def false_fn(): + return constant_op.constant(2.0) + + out = cond_v2.cond_v2(pred, true_fn, false_fn) + + with self.test_session() as sess: + self.assertEqual(sess.run(out, {pred: True}), [1.0]) + self.assertEqual(sess.run(out, {pred: False}), [2.0]) + + def testSecondDerivative(self): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + cond_grad = gradients_impl.gradients(cond, [x]) + cond_grad_grad = gradients_impl.gradients(cond_grad, [x]) + + with self.test_session() as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + def testGradientOfDeserializedCond(self): + with ops.Graph().as_default(): + pred = array_ops.placeholder(dtypes.bool, name="pred") + x = constant_op.constant(3.0, name="x") + ops.add_to_collection("x", x) + + def true_fn(): + return math_ops.pow(x, 3) + + def false_fn(): + return x + + ops.add_to_collection("pred", pred) + cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond") + for c in cond: + ops.add_to_collection("cond", c) + meta_graph = saver.export_meta_graph() + + with ops.Graph().as_default() as g: + saver.import_meta_graph(meta_graph) + x = ops.get_collection("x")[0] + pred = ops.get_collection("pred")[0] + cond = ops.get_collection("cond") + cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad") + cond_grad_grad = gradients_impl.gradients( + cond_grad, [x], name="cond_grad_grad") + with self.test_session(graph=g) as sess: + # d[x^3]/dx = 3x^2 + true_val = sess.run(cond_grad, {pred: True}) + self.assertEqual(true_val, [27.0]) + # d[x]/dx = 1 + false_val = sess.run(cond_grad, {pred: False}) + self.assertEqual(false_val, [1.0]) + + true_val = sess.run(cond_grad_grad, {pred: True}) + # d2[x^3]/dx2 = 6x + self.assertEqual(true_val, [18.0]) + false_val = sess.run(cond_grad_grad, {pred: False}) + # d2[x]/dx2 = 0 + self.assertEqual(false_val, [0.0]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 102bc460fdadb0ad5dc9a2960b8655c55357108e..a0dd3881a86c19e47ccb65f84a2477a55626b81c 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -218,7 +218,6 @@ def copy_op_to_graph(org_instance, to_graph, variables, scope=''): new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index ed0a26bbd87eeb5bd005de8f9d054d315e378529..8822a7523f6b168f569e29970c9c29f2eb3614fc 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -20,7 +20,6 @@ from __future__ import print_function import os from tensorflow.contrib.checkpoint.python import split_dependency from tensorflow.contrib.rnn.python.ops import lstm_ops -from tensorflow.python.framework import common_shapes from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -1647,10 +1646,3 @@ class CudnnRNNRelu(_CudnnRNNNoInputC): # 1 set of weight and bias parameters for the recurrent input, and 1 for the # previous layer input. _NUM_PARAMS_PER_LAYER = CUDNN_RNN_RELU_PARAMS_PER_LAYER - - -ops.RegisterShape("CudnnRNNParamsSize")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNParamsToCanonical")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNCanonicalToParams")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNN")(common_shapes.call_cpp_shape_fn) -ops.RegisterShape("CudnnRNNBackprop")(common_shapes.call_cpp_shape_fn) diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index a25aa85251083c24ca6685c4ffef267955f66f63..1af1ed08b53ee04367eb316d5c9caa0216f2e88d 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -30,6 +30,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@assert_element_shape @@batch_and_drop_remainder @@bucket_by_sequence_length +@@choose_from_datasets @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc index 76e54a284e07ec1bab9b0f364a44997a39bce78a..4657807785d58727d34f37172bd30c56a5b7cde6 100644 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/random_inputstream.h" namespace tensorflow { @@ -103,12 +102,11 @@ class CSVDatasetOp : public DatasetOpKernel { OP_REQUIRES( ctx, select_cols.empty() || select_cols.front() >= 0, errors::InvalidArgument("select_cols should be non-negative indices")); - bool select_all_cols = select_cols.empty(); - *output = new Dataset( - ctx, std::move(filenames), header, buffer_size, output_types_, - output_shapes_, std::move(record_defaults), std::move(select_cols), - select_all_cols, use_quote_delim, delim[0], std::move(na_value)); + *output = new Dataset(ctx, std::move(filenames), header, buffer_size, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); } private: @@ -118,8 +116,7 @@ class CSVDatasetOp : public DatasetOpKernel { int64 buffer_size, const DataTypeVector& output_types, const std::vector& output_shapes, std::vector record_defaults, std::vector select_cols, - bool select_all_cols, bool use_quote_delim, char delim, - string na_value) + bool use_quote_delim, char delim, string na_value) : GraphDatasetBase(ctx), filenames_(std::move(filenames)), header_(header), @@ -128,12 +125,11 @@ class CSVDatasetOp : public DatasetOpKernel { output_shapes_(output_shapes), record_defaults_(std::move(record_defaults)), select_cols_(std::move(select_cols)), - select_all_cols_(select_all_cols), use_quote_delim_(use_quote_delim), delim_(delim), na_value_(std::move(na_value)) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::CSV")})); @@ -145,7 +141,7 @@ class CSVDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { return "CSVDatasetOp::Dataset"; } + string DebugString() const override { return "CSVDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, @@ -166,11 +162,24 @@ class CSVDatasetOp : public DatasetOpKernel { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); do { // We are currently processing a file, so try to read the next record - if (buffered_input_stream_) { - Status s = ReadRecord(ctx, out_tensors); - if (s.ok() || !errors::IsOutOfRange(s)) { + if (input_stream_) { + Status s = ReadRecord(ctx, out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { // Not at the end of file, return OK or non-EOF errors to caller. *end_of_sequence = false; return s; @@ -203,145 +212,317 @@ class CSVDatasetOp : public DatasetOpKernel { } private: - // Reads a record by parsing the input buffer, and converting extracted + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted // fields to output tensors as we go. - Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors) + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer. + // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. + Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, + bool select_all, const std::vector& selected) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - // Extracts fields from line(s) from the buffered input stream. - out_tensors->reserve(dataset()->record_defaults_.size()); - - string input; - TF_RETURN_IF_ERROR(buffered_input_stream_->ReadLine(&input)); - - size_t current_idx = 0; - size_t num_fields_parsed = 0; - size_t selector_idx = 0; // Keep track of index into select_cols - - while (current_idx < input.size()) { - // In each iteration, parse one field - if (input[current_idx] == '\n' || input[current_idx] == '\r') { - // This should never happen, because buffered input reader splits - // input on newlines. - return errors::InvalidArgument("Parsing error."); - } + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. + + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; + + Status result; - bool quoted = false; + while (!end_of_record) { // Read till we reach \n, \r or EOF bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); - if (dataset()->use_quote_delim_ && input[current_idx] == '"') { - quoted = true; - current_idx++; + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); + + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon. + Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller } - // Parse the body of the field - string field; - if (!quoted) { - while (current_idx < input.size() && - input[current_idx] != dataset()->delim_) { - if ((dataset()->use_quote_delim_ && input[current_idx] == '"') || - input[current_idx] == '\n' || input[current_idx] == '\r') { - return errors::InvalidArgument( - "Unquoted fields cannot have quotes/CRLFs inside"); + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + Status parse_result; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(), out_tensors, earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; } - if (include) field += input[current_idx]; - current_idx++; - } // Exit condition: end of input, or current index at delim + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + return parse_result; + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + if (next == '\r') SkipNewLineIfNecessary(); + return parse_result; + } else if (next != '"') { + // Take note of the error, but keep going to end of field. + include = false; // So we don't get funky errors when trying to + // unescape the quotes. + parse_result.Update(errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote")); + } - // Go to next field or the end - current_idx++; } else { - // Quoted field needs to be ended with '"' and delim or end - while (true) { - if (current_idx >= input.size() - 1 || input.empty()) { - if (current_idx == input.size() - 1 && - input[current_idx] == '"') { - // We're at the end of the input, and the quote terminates the - // record. Go to end. - current_idx++; - break; - } - // If there's no terminating quote, it means our buffered record - // line reader split a record up. This can happen if there is a - // newline encased in quotes. The next line is also part of the - // record, so we read it and reset the index. - if (include && current_idx == input.size() - 1) { - // TODO(rachelim): Instead of building up a string, keep track - // of terminal indices (or starting char* and length) - // Also look into using /lib/strings/Scanner - field += input[current_idx]; - } - if (include) { - field += '\n'; - } - current_idx = 0; - Status s = buffered_input_stream_->ReadLine(&input); - if (!s.ok()) { - return errors::InvalidArgument( - "Quoted field has to end with quote followed by delim, " - "CRLF, or EOF"); - } - } else if (input[current_idx] == '"' && - input[current_idx + 1] == dataset()->delim_) { - // End of field, go to next field or end - current_idx += 2; - break; - } else if (input[current_idx] == '"') { - // Current char is a quote. Since we're not at end of field, - // the next character must also be a quote. - if (input[current_idx + 1] != '"') { - return errors::InvalidArgument( - "Quote inside a string has to be escaped by another " - "quote"); - } - if (include) field += '"'; - current_idx += 2; - } else { - if (include) field += input[current_idx]; - current_idx++; - } + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. + bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. + Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector earlier_pieces; + size_t start = pos_; + Status parse_result; + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; // Surface all other errors to caller } } - num_fields_parsed++; + char ch = buffer_[pos_]; - if (include) { - // Add the tensor to the result - TF_RETURN_IF_ERROR(FieldToOutput(ctx, std::move(field), - selector_idx, out_tensors)); - selector_idx++; - // Terminate early if we have all the fields we want - if (selector_idx == dataset()->select_cols_.size()) - return Status::OK(); + if (ch == dataset()->delim_) { + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + pos_++; + return parse_result; } - } // Exit condition: current_idx has reached the end of record - - // Check if the last field is empty, and include it if necessary - bool include = - (dataset()->select_all_cols_ || - dataset()->select_cols_[selector_idx] == num_fields_parsed); - if (include && !input.empty() && - input[input.size() - 1] == dataset()->delim_) { - TF_RETURN_IF_ERROR( - FieldToOutput(ctx, string(), selector_idx, out_tensors)); + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return parse_result; + } + if (dataset()->use_quote_delim_ && ch == '"') { + // Take note of the error, but keep going to end of field. + parse_result.Update(errors::InvalidArgument( + "Unquoted fields cannot have quotes inside")); + } + // Otherwise, go to next character + pos_++; } + } - // Check that number of fields matches - if (out_tensors->size() != dataset()->out_type_.size()) { - return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), - " fields but have ", - out_tensors->size(), " in record"); + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + Status s = input_stream_->ReadNBytes(dataset()->buffer_size_, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. + return Status::OK(); } - return Status::OK(); + return s; } - // Given a string field, and its index in the output, - // converts it to a Tensor of the right type and adds it to the - // out_tensors vector. - Status FieldToOutput(IteratorContext* ctx, string field, - size_t output_idx, + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, std::vector* out_tensors) { + size_t output_idx = out_tensors->size(); if (output_idx >= dataset()->out_type_.size()) { // We can get here if we're selecting all columns, but the number of // fields exceeds the number of defaults provided @@ -397,7 +578,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { float value; - if (!strings::safe_strtof(field.c_str(), &value)) { + if (!strings::safe_strtof(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid float: ", field); @@ -412,7 +593,7 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->record_defaults_[output_idx].flat()(0); } else { double value; - if (!strings::safe_strtod(field.c_str(), &value)) { + if (!strings::safe_strtod(field, &value)) { return errors::InvalidArgument( "Field ", output_idx, " in record is not a valid double: ", field); @@ -426,7 +607,7 @@ class CSVDatasetOp : public DatasetOpKernel { component.scalar()() = dataset()->record_defaults_[output_idx].flat()(0); } else { - component.scalar()() = std::move(field); + component.scalar()() = field.ToString(); } break; } @@ -439,6 +620,50 @@ class CSVDatasetOp : public DatasetOpKernel { return Status::OK(); } + // Records can be delimited by "\r\n" line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + // Sets up reader streams to read from the file at `current_file_index_`. Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (current_file_index_ >= dataset()->filenames_.size()) { @@ -452,16 +677,18 @@ class CSVDatasetOp : public DatasetOpKernel { dataset()->filenames_[current_file_index_], &file_)); input_stream_.reset( new io::RandomAccessInputStream(file_.get(), false)); - // TODO(rachelim): Maintain our own buffer so we don't read every record - // twice - buffered_input_stream_.reset(new io::BufferedInputStream( - input_stream_.get(), dataset()->buffer_size_, false)); + buffer_.clear(); + pos_ = 0; if (dataset()->header_) { - // Ignore header line - string str; - Status s = buffered_input_stream_->ReadLine(&str); - if (errors::IsOutOfRange(s)) { - return errors::InvalidArgument("Can't read header of empty file"); + // Read one line, but don't include it. Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); } } return Status::OK(); @@ -470,15 +697,15 @@ class CSVDatasetOp : public DatasetOpKernel { // Resets all reader streams. void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { input_stream_.reset(); - buffered_input_stream_.reset(); file_.reset(); } mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters std::unique_ptr input_stream_ GUARDED_BY(mu_); - std::unique_ptr buffered_input_stream_ - GUARDED_BY(mu_); size_t current_file_index_ GUARDED_BY(mu_) = 0; std::unique_ptr file_ GUARDED_BY(mu_); // must outlive input_stream_ @@ -491,7 +718,6 @@ class CSVDatasetOp : public DatasetOpKernel { const std::vector output_shapes_; const std::vector record_defaults_; const std::vector select_cols_; - const bool select_all_cols_; const bool use_quote_delim_; const char delim_; const string na_value_; diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc index 48d3734162525ffc6ace076e4f0523c1d0cae511..6a12ca06f4d6cc2096aaf8191a01a899881b43db 100644 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc @@ -91,7 +91,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { } } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr(new Iterator( {this, strings::StrCat(prefix, "::DirectedInterleave")})); @@ -105,7 +105,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { return output_shapes_; } - string DebugString() override { + string DebugString() const override { return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); } @@ -130,15 +130,21 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator(params), - selector_input_impl_(params.dataset->selector_input_->MakeIterator( - params.prefix + ".selector")), - num_active_inputs_(params.dataset->data_inputs_.size()) { - data_input_impls_.reserve(params.dataset->data_inputs_.size()); - for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) { - const DatasetBase* data_input = params.dataset->data_inputs_[i]; - data_input_impls_.push_back(data_input->MakeIterator( - strings::StrCat(params.prefix, "[", i, "]"))); + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); } + return Status::OK(); } Status GetNextInternal(IteratorContext* ctx, diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index bb29df60e8f114aaa50f578c43e73874f72ab0a3..bbec50681c6f5decec5a3b5fbf09cc3011a21199 100644 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); @@ -57,7 +57,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "IgnoreErrorsDatasetOp::Dataset"; } + string DebugString() const override { + return "IgnoreErrorsDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -72,8 +74,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc index 63e19ae3f837c9d3cfb1221df64360ee74117f13..3dfc3741c2b040dd5be3223c24f0715ba3be4248 100644 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -127,7 +127,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { threadpool_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); @@ -140,7 +140,9 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; } + string DebugString() const override { + return "ThreadPoolDatasetOp::Dataset"; + } protected: Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, @@ -154,8 +156,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc index 69fbb0fcdcce87951d2c9b84210fda378081b103..67c237799c10a2724f18bb0df99e4bf8f5cd2b8a 100644 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { ~Dataset() override { input_->Unref(); } - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Unique")})); @@ -70,7 +70,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { return input_->output_shapes(); } - string DebugString() override { + string DebugString() const override { return strings::StrCat("UniqueDatasetOp::Dataset"); } @@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator { public: explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator(params), - input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + : DatasetIterator(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index f5082228e885d065e659abf208ca7b94bb4999a5..0dfd249ec27d96d6f1a4ae65d623df456db9991f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -54,6 +54,19 @@ py_test( ], ) +py_test( + name = "cache_dataset_op_test", + size = "small", + srcs = ["cache_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "concatenate_dataset_op_test", size = "small", @@ -128,6 +141,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:error_ops", "//tensorflow/contrib/data/python/ops:readers", "//third_party/py/numpy", ], @@ -208,6 +222,23 @@ py_test( ], ) +py_test( + name = "directed_interleave_dataset_test", + size = "medium", + srcs = ["directed_interleave_dataset_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:interleave_ops", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "get_single_element_test", size = "small", @@ -262,6 +293,19 @@ py_test( ], ) +py_test( + name = "optimize_dataset_op_test", + size = "small", + srcs = ["optimize_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "prefetch_dataset_op_test", size = "small", @@ -299,6 +343,26 @@ py_test( ], ) +py_library( + name = "reader_dataset_ops_test_base", + testonly = 1, + srcs = [ + "reader_dataset_ops_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:readers", + ], +) + py_test( name = "reader_dataset_ops_test", size = "medium", @@ -308,8 +372,8 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -321,6 +385,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", "//third_party/py/numpy", ], ) @@ -410,6 +475,7 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -417,6 +483,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -447,10 +514,15 @@ py_test( tags = ["no_pip"], deps = [ ":dataset_serialization_test", + ":reader_dataset_ops_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 2568b899d7ea1be685036ad8af93f584f861c951..b5fbc45ad3d8d262c1c79b5723ffeb38ff6a34c2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -552,6 +552,44 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testMapAndBatchParallelGetNext(self): + iterator = (dataset_ops.Dataset.range(50000) + .apply(batching.map_and_batch(lambda x: x, batch_size=100)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(5): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + + def testMapAndBatchParallelGetNextDropRemainder(self): + iterator = ( + dataset_ops.Dataset.range(49999).apply( + batching.map_and_batch( + lambda x: x, batch_size=100, drop_remainder=True)) + .make_one_shot_iterator()) + elements = [] + for _ in range(100): + elements.append(iterator.get_next()) + with self.test_session() as sess: + for i in range(4): + got = sess.run(elements) + got.sort(key=lambda x: x[0]) + expected = [] + for j in range(100): + expected.append(range(i*10000+j*100, i*10000+(j+1)*100)) + self.assertAllEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(elements) + def testMapAndBatchSparse(self): def _sparse(i): diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f08216a303e2d7dee155ccadcdb9f42f1b24ea0f --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental features of CacheDataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class CacheToFileDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self.range_size = 10 + self.num_repeats = 3 + self.num_outputs = self.range_size * self.num_repeats + self.cache_file_prefix = 'test' + + def ds_fn(self): + return dataset_ops.Dataset.range(self.range_size).cache( + os.path.join(self.get_temp_dir(), + self.cache_file_prefix)).repeat(self.num_repeats) + + def expected_outputs(self): + return list(range(self.range_size)) * self.num_repeats + + def testCheckpointBeforeOneEpoch(self): + # Generate 5 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointBeforeOneEpochThenRunFewSteps(self): + # Generate 8 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 8, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, range(8)) + + # Restoring from checkpoint and running GetNext should return a + # `AlreadExistsError` now because the lockfile already exists. + with self.assertRaises(errors.AlreadyExistsError): + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + + def testCheckpointAfterOneEpoch(self): + # Generate 15 entries from iterator and save checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Restore from checkpoint and produce the rest of the elements from the + # iterator. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, self.expected_outputs()) + + def testCheckpointAfterOneEpochThenRunFewSteps(self): + # Generate 18 entries from iterator but save checkpoint after producing + # 15. + outputs = self.gen_outputs( + self.ds_fn, [15], + 18, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(8))) + + outputs = list(range(10)) + list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 15, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointBeforeOneEpochButRunCompleteEpoch(self): + # Generate 13 entries from iterator but save checkpoint after producing + # 5. + outputs = self.gen_outputs( + self.ds_fn, [5], + 13, + verify_exhausted=False, + save_checkpoint_at_end=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(3))) + + # Since we ran for more than one epoch, the cache was completely written. + # The ckpt was saved when the iterator was in cache-write mode. Test that + # the iterator falls back to read mode after restoring if the cache has + # been completely written. + + outputs = list(range(5)) + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedWriterIterator(self): + # Checkpoint before get_next is called even once. + outputs = self.gen_outputs(self.ds_fn, [], 0, verify_exhausted=False) + self.assertSequenceEqual(outputs, []) + + outputs = self.gen_outputs( + self.ds_fn, [], + self.num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testCheckpointUnusedMidwayWriterIterator(self): + # Produce 5 elements and checkpoint. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint, then produce no elements and checkpoint. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False)) + self.assertSequenceEqual(outputs, range(5)) + + # Restore from checkpoint and produce rest of the elements. + outputs.extend( + self.gen_outputs( + self.ds_fn, [], + self.num_outputs - 5, + ckpt_saved=True, + verify_exhausted=False)) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + def testUnusedCheckpointError(self): + # Produce 5 elements and save ckpt. + outputs = self.gen_outputs(self.ds_fn, [], 5, verify_exhausted=False) + self.assertSequenceEqual(outputs, range(5)) + + # Since the complete cache has not been written, a new iterator which does + # not restore the checkpoint will throw an error since there is a partial + # cache shard. + with self.assertRaises(errors.AlreadyExistsError): + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + + def testIgnoreCheckpointIfCacheWritten(self): + # Produce 15 elements and save ckpt. This will write the complete cache. + outputs = self.gen_outputs(self.ds_fn, [], 15, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) + list(range(5))) + + # Build the iterator again but do not restore from ckpt. Since the cache + # has already been written we should be able to use it. + outputs = self.gen_outputs( + self.ds_fn, [], self.num_outputs, verify_exhausted=False) + self.assertSequenceEqual(outputs, list(range(10)) * 3) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index 641a389c033687ebe081963182390b00230e4cb5..97b5e9416521dcad9ee5047a8275f8fd0142e338 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -19,11 +19,13 @@ from __future__ import division from __future__ import print_function import os +import string import tempfile import time import numpy as np +from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session from tensorflow.python.data.ops import readers as core_readers @@ -60,12 +62,12 @@ class CsvDatasetOpTest(test.TestCase): op2 = sess.run(next2) self.assertAllEqual(op1, op2) - def setup_files(self, inputs): + def setup_files(self, inputs, linebreak='\n'): filenames = [] for i, ip in enumerate(inputs): - fn = os.path.join(self.get_temp_dir(), 'temp_%d.txt' % i) - with open(fn, 'w') as f: - f.write('\n'.join(ip)) + fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i) + with open(fn, 'wb') as f: + f.write(linebreak.join(ip).encode('utf-8')) filenames.append(fn) return filenames @@ -85,38 +87,47 @@ class CsvDatasetOpTest(test.TestCase): inputs, **kwargs) self._assert_datasets_equal(g, dataset_actual, dataset_expected) + def _verify_output_or_err(self, + sess, + dataset, + expected_output=None, + expected_err_re=None): + nxt = dataset.make_one_shot_iterator().get_next() + if expected_err_re is None: + # Verify that output is expected, without errors + expected_output = [[ + v.encode('utf-8') if isinstance(v, str) else v for v in op + ] for op in expected_output] + for value in expected_output: + op = sess.run(nxt) + self.assertAllEqual(op, value) + with self.assertRaises(errors.OutOfRangeError): + sess.run(nxt) + else: + # Verify that OpError is produced as expected + with self.assertRaisesOpError(expected_err_re): + while True: + try: + sess.run(nxt) + except errors.OutOfRangeError: + break + def _test_dataset(self, inputs, expected_output=None, expected_err_re=None, + linebreak='\n', **kwargs): """Checks that elements produced by CsvDataset match expected output.""" # Convert str type because py3 tf strings are bytestrings - filenames = self.setup_files(inputs) + filenames = self.setup_files(inputs, linebreak) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: dataset = readers.CsvDataset(filenames, **kwargs) - nxt = dataset.make_one_shot_iterator().get_next() - if expected_err_re is None: - # Verify that output is expected, without errors - expected_output = [[ - v.encode('utf-8') if isinstance(v, str) else v for v in op - ] for op in expected_output] - for value in expected_output: - op = sess.run(nxt) - self.assertAllEqual(op, value) - with self.assertRaises(errors.OutOfRangeError): - sess.run(nxt) - else: - # Verify that OpError is produced as expected - with self.assertRaisesOpError(expected_err_re): - while True: - try: - sess.run(nxt) - except errors.OutOfRangeError: - break - - def testCsvDataset_floatRequired(self): + self._verify_output_or_err(sess, dataset, expected_output, + expected_err_re) + + def testCsvDataset_requiredFields(self): record_defaults = [[]] * 4 inputs = [['1,2,3,4']] self._test_by_comparison(inputs, record_defaults=record_defaults) @@ -136,10 +147,55 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withQuoted(self): - record_defaults = [['']] * 4 - inputs = [['1.0,2.1,"hello, it is me",4.3', '5.4,6.5,goodbye,8.7']] - self._test_by_comparison(inputs, record_defaults=record_defaults) + def testCsvDataset_withEmptyFields(self): + record_defaults = [[0]] * 4 + inputs = [[',,,', '1,1,1,', ',2,2,2']] + self._test_dataset( + inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]], + record_defaults=record_defaults) + + def testCsvDataset_errWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_dataset( + inputs, + expected_err_re='Unquoted fields cannot have quotes inside', + record_defaults=record_defaults) + + def testCsvDataset_errWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['"a"b","c","d"']] + self._test_dataset( + inputs, + expected_err_re= + 'Quote inside a string has to be escaped by another quote', + record_defaults=record_defaults) + + def testCsvDataset_ignoreErrWithUnescapedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_ignoreErrWithUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']] + filenames = self.setup_files(inputs) + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + dataset = readers.CsvDataset(filenames, record_defaults=record_defaults) + dataset = dataset.apply(error_ops.ignore_errors()) + self._verify_output_or_err(sess, dataset, [['e', 'f', 'g']]) + + def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self): + record_defaults = [['']] * 3 + inputs = [['1,2"3,4']] + self._test_by_comparison( + inputs, record_defaults=record_defaults, use_quote_delim=False) def testCsvDataset_mixedTypes(self): record_defaults = [ @@ -163,11 +219,6 @@ class CsvDatasetOpTest(test.TestCase): self._test_by_comparison( inputs, record_defaults=record_defaults, field_delim=':') - def testCsvDataset_withEmptyValues(self): - record_defaults = [[0]] * 4 - inputs = [['1,,3,4', ',6,7,8']] - self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNaValue(self): record_defaults = [[0]] * 4 inputs = [['1,NA,3,4', 'NA,6,7,8']] @@ -175,8 +226,8 @@ class CsvDatasetOpTest(test.TestCase): inputs, record_defaults=record_defaults, na_value='NA') def testCsvDataset_withSelectCols(self): - record_defaults = [[0]] * 2 - inputs = [['1,2,3,4', '5,6,7,8']] + record_defaults = [['']] * 2 + inputs = [['1,2,3,4', '"5","6","7","8"']] self._test_by_comparison( inputs, record_defaults=record_defaults, select_cols=[1, 2]) @@ -189,27 +240,17 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, select_cols=[3, 4]) + def testCsvDataset_withOneCol(self): + record_defaults = [['NA']] + inputs = [['0', '', '2']] + self._test_dataset( + inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults) + def testCsvDataset_withMultipleFiles(self): record_defaults = [[0]] * 4 inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']] self._test_by_comparison(inputs, record_defaults=record_defaults) - def testCsvDataset_withNewLine(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - - def testCsvDataset_withMultipleNewLines(self): - # In this case, we expect it to behave differently from - # TextLineDataset->map(decode_csv) since that flow has bugs - record_defaults = [['']] * 4 - inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] - expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] - self._test_dataset(inputs, expected, record_defaults=record_defaults) - def testCsvDataset_withLeadingAndTrailingSpaces(self): record_defaults = [[0.0]] * 4 inputs = [['0, 1, 2, 3']] @@ -265,9 +306,10 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_errorWithHeaderEmptyFile(self): record_defaults = [[0]] * 2 inputs = [[]] + expected_err_re = "Can't read header of file" self._test_dataset( inputs, - expected_err_re="Can't read header of empty file", + expected_err_re=expected_err_re, record_defaults=record_defaults, header=True, ) @@ -283,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['', '1,2']] # First record is empty self._test_dataset( inputs, - expected_err_re='Expect 2 fields but have 0 in record', + expected_err_re='Expect 2 fields but have 1 in record', record_defaults=record_defaults) def testCsvDataset_withChainedOps(self): @@ -300,7 +342,7 @@ class CsvDatasetOpTest(test.TestCase): def testCsvDataset_withTypeDefaults(self): # Testing using dtypes as record_defaults for required fields - record_defaults = [dtypes.float32, dtypes.float32] + record_defaults = [dtypes.float32, [0.0]] inputs = [['1.0,2.0', '3.0,4.0']] self._test_dataset( inputs, @@ -308,71 +350,270 @@ class CsvDatasetOpTest(test.TestCase): record_defaults=record_defaults, ) + def testMakeCsvDataset_fieldOrder(self): + data = [[ + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19', + '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19' + ]] + file_path = self.setup_files(data) + + with ops.Graph().as_default() as g: + ds = readers.make_csv_dataset( + file_path, batch_size=1, shuffle=False, num_epochs=1) + next_batch = ds.make_one_shot_iterator().get_next() + + with self.test_session(graph=g) as sess: + result = list(sess.run(next_batch).values()) + + self.assertEqual(result, sorted(result)) + +## The following tests exercise parsing logic for quoted fields + + def testCsvDataset_withQuoted(self): + record_defaults = [['']] * 4 + inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + def testCsvDataset_withOneColAndQuotes(self): + record_defaults = [['']] + inputs = [['"0"', '"1"', '"2"']] + self._test_dataset( + inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults) + + def testCsvDataset_withNewLine(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_withNewLineInUnselectedCol(self): + record_defaults = [['']] + inputs = [['1,"2\n3",4', '5,6,7']] + self._test_dataset( + inputs, + expected_output=[['1'], ['5']], + record_defaults=record_defaults, + select_cols=[0]) + + def testCsvDataset_withMultipleNewLines(self): + # In this case, we expect it to behave differently from + # TextLineDataset->map(decode_csv) since that flow has bugs + record_defaults = [['']] * 4 + inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']] + expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']] + self._test_dataset(inputs, expected, record_defaults=record_defaults) + + def testCsvDataset_errorWithTerminateMidRecord(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,"a']] + self._test_dataset( + inputs, + expected_err_re= + 'Reached end of file without closing quoted field in record', + record_defaults=record_defaults) + + def testCsvDataset_withEscapedQuotes(self): + record_defaults = [['']] * 4 + inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']] + self._test_by_comparison(inputs, record_defaults=record_defaults) + + +## Testing that parsing works with all buffer sizes, quoted/unquoted fields, +## and different types of line breaks + + def testCsvDataset_withInvalidBufferSize(self): + record_defaults = [['']] * 4 + inputs = [['a,b,c,d']] + self._test_dataset( + inputs, + expected_err_re='buffer_size should be positive', + record_defaults=record_defaults, + buffer_size=0) + + def testCsvDataset_withBufferSize(self): + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, expected, record_defaults=record_defaults, buffer_size=i + 1) + + def testCsvDataset_withCR(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withCRLF(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['abc,def,ghi', '0,1,2', ',,']] + expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + + def testCsvDataset_withBufferSizeAndQuoted(self): + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\n', record_defaults=record_defaults) + + def testCsvDataset_withCRAndQuoted(self): + # Test that when the line separator is '\r', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r', record_defaults=record_defaults) + + def testCsvDataset_withCRLFAndQuoted(self): + # Test that when the line separator is '\r\n', parsing works with all buffer + # sizes + record_defaults = [['NA']] * 3 + inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']] + expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'], + ['NA', 'NA', 'NA']] + for i in range(20): + # Test a range of buffer sizes that should all work + self._test_dataset( + inputs, + expected, + linebreak='\r\n', + record_defaults=record_defaults, + buffer_size=i + 1) + self._test_dataset( + inputs, expected, linebreak='\r\n', record_defaults=record_defaults) + class CsvDatasetBenchmark(test.Benchmark): """Benchmarks for the various ways of creating a dataset from CSV files. """ + FLOAT_VAL = '1.23456E12' + STR_VAL = string.ascii_letters * 10 - def _setUp(self): + def _setUp(self, str_val): # Since this isn't test.TestCase, have to manually create a test dir gfile.MakeDirs(googletest.GetTempDir()) self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir()) self._num_cols = [4, 64, 256] - self._batch_size = 500 + self._num_per_iter = 5000 self._filenames = [] for n in self._num_cols: fn = os.path.join(self._temp_dir, 'file%d.csv' % n) - with open(fn, 'w') as f: - # Just write 10 rows and use `repeat`... - row = ','.join(['1.23456E12' for _ in range(n)]) - f.write('\n'.join([row for _ in range(10)])) + with open(fn, 'wb') as f: + # Just write 100 rows and use `repeat`... Assumes the cost + # of creating an iterator is not significant + row = ','.join([str_val for _ in range(n)]) + f.write('\n'.join([row for _ in range(100)])) self._filenames.append(fn) def _tearDown(self): gfile.DeleteRecursively(self._temp_dir) def _runBenchmark(self, dataset, num_cols, prefix): - next_element = dataset.make_one_shot_iterator().get_next() - with session.Session() as sess: - for _ in range(5): - sess.run(next_element) - deltas = [] - for _ in range(10): + dataset = dataset.skip(self._num_per_iter - 1) + deltas = [] + for _ in range(10): + next_element = dataset.make_one_shot_iterator().get_next() + with session.Session() as sess: start = time.time() + # NOTE: This depends on the underlying implementation of skip, to have + # the net effect of calling `GetNext` num_per_iter times on the + # input dataset. We do it this way (instead of a python for loop, or + # batching N inputs in one iter) so that the overhead from session.run + # or batch doesn't dominate. If we eventually optimize skip, this has + # to change. sess.run(next_element) end = time.time() - deltas.append(end - start) - median_wall_time = np.median(deltas) / 100 + deltas.append(end - start) + # Median wall time per CSV record read and decoded + median_wall_time = np.median(deltas) / self._num_per_iter print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols, median_wall_time)) self.report_benchmark( - iters=self._batch_size, + iters=self._num_per_iter, wall_time=median_wall_time, name='%s_with_cols_%d' % (prefix, num_cols)) - def benchmarkBatchThenMap(self): - self._setUp() + def benchmarkMapWithFloats(self): + self._setUp(self.FLOAT_VAL) for i in range(len(self._filenames)): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop - dataset = dataset.batch(self._batch_size) - self._runBenchmark(dataset, num_cols, 'csv_map_then_batch') + self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv') + self._tearDown() + + def benchmarkMapWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = dataset.map(lambda l: gen_parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv') self._tearDown() - def benchmarkCsvDataset(self): - self._setUp() + def benchmarkCsvDatasetWithFloats(self): + self._setUp(self.FLOAT_VAL) for i in range(len(self._filenames)): num_cols = self._num_cols[i] kwargs = {'record_defaults': [[0.0]] * num_cols} dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop - dataset = dataset.batch(self._batch_size) - self._runBenchmark(dataset, num_cols, 'csv_fused_dataset') + self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset') self._tearDown() + def benchmarkCsvDatasetWithStrings(self): + self._setUp(self.STR_VAL) + for i in range(len(self._filenames)): + num_cols = self._num_cols[i] + kwargs = {'record_defaults': [['']] * num_cols} + dataset = core_readers.TextLineDataset(self._filenames[i]).repeat() + dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop + self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset') + self._tearDown() if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 78ecce8f7daaf84002ae78d8d77820755b967d89..393f08850b1865180a8b94e9209b2445b54c8b69 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -467,7 +467,8 @@ class DatasetSerializationTestBase(test.TestCase): ckpt_saved=False, init_before_restore=False, sparse_tensors=False, - verify_exhausted=True): + verify_exhausted=True, + save_checkpoint_at_end=True): """Generates elements from input dataset while stopping at break points. Produces `num_outputs` outputs and saves the state of the iterator in the @@ -490,6 +491,10 @@ class DatasetSerializationTestBase(test.TestCase): sparse_tensors: Whether dataset is built from SparseTensor(s). verify_exhausted: Whether to verify that the iterator has been exhausted after producing `num_outputs` elements. + save_checkpoint_at_end: Whether to save a checkpoint after producing all + outputs. If False, checkpoints are saved each break point but not at the + end. Note that checkpoints overwrite each other so there is always only + a single checkpoint available. Defaults to True. Returns: A list of `num_outputs` items. @@ -526,8 +531,9 @@ class DatasetSerializationTestBase(test.TestCase): if i == len(break_points) and verify_exhausted: with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) - self._save(sess, saver) - ckpt_saved = True + if save_checkpoint_at_end or i < len(break_points): + self._save(sess, saver) + ckpt_saved = True return outputs diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34b6a080c0aae7dfc228746139acc52cea4e6f28 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -0,0 +1,167 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import random_seed +from tensorflow.python.platform import test + + +class DirectedInterleaveDatasetTest(test.TestCase): + + def testBasic(self): + selector_dataset = dataset_ops.Dataset.range(10).repeat(100) + input_datasets = [ + dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) + ] + dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, + input_datasets) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(100): + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def _normalize(self, vec): + return vec / vec.sum() + + def _chi2(self, expected, actual): + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected, axis=0) + return chi2 + + def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): + # Create a dataset that samples each integer in `[0, num_datasets)` + # with probability given by `weights[i]`. + dataset = interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(num_datasets) + ], weights) + dataset = dataset.take(num_samples) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + freqs = np.zeros([num_datasets]) + for _ in range(num_samples): + freqs[sess.run(next_element)] += 1 + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + return freqs + + def testSampleFromDatasets(self): + random_seed.set_random_seed(1619) + num_samples = 5000 + rand_probs = self._normalize(np.random.random_sample((15,))) + + # Use chi-squared test to assert that the observed distribution matches the + # expected distribution. Based on the implementation in + # "tensorflow/python/kernel_tests/multinomial_op_test.py". + for probs in [[.85, .05, .1], rand_probs]: + probs = np.asarray(probs) + classes = len(probs) + freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + # Also check that `weights` as a dataset samples correctly. + probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() + freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) + self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2) + + def testSelectFromDatasets(self): + words = [b"foo", b"bar", b"baz"] + datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words] + choice_array = np.random.randint(3, size=(15,), dtype=np.int64) + choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array) + dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in choice_array: + self.assertEqual(words[i], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testErrors(self): + with self.assertRaisesRegexp(ValueError, + r"vector of length `len\(datasets\)`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[0.25, 0.25, 0.25, 0.25]) + + with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): + interleave_ops.sample_from_datasets( + [dataset_ops.Dataset.range(10), + dataset_ops.Dataset.range(20)], + weights=[1, 1]) + + with self.assertRaisesRegexp(TypeError, "must have the same type"): + interleave_ops.sample_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(0.0) + ]) + + with self.assertRaisesRegexp(TypeError, "tf.int64"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0)) + + with self.assertRaisesRegexp(TypeError, "scalar"): + interleave_ops.choose_from_datasets([ + dataset_ops.Dataset.from_tensors(0), + dataset_ops.Dataset.from_tensors(1) + ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0])) + + +class SampleFromDatasetsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, probs, num_samples): + dataset = interleave_ops.sample_from_datasets( + [ + dataset_ops.Dataset.from_tensors(i).repeat(None) + for i in range(len(probs)) + ], + probs, + seed=1813) + return dataset.take(num_samples) + + def testSerializationCore(self): + self.run_core_tests( + lambda: self._build_dataset([0.5, 0.5], 100), + lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 43aa4b1bd02791ff304a990c0bbe8e45534c0c77..bee561e3e23a2ab6f314894caa21785347e6ca8b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -30,7 +30,6 @@ from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import random_seed from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -907,114 +906,5 @@ class ParallelInterleaveDatasetTest(test.TestCase): sess.run(self.next_element) -class DirectedInterleaveDatasetTest(test.TestCase): - - def testBasic(self): - selector_dataset = dataset_ops.Dataset.range(10).repeat(100) - input_datasets = [ - dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10) - ] - dataset = interleave_ops.DirectedInterleaveDataset(selector_dataset, - input_datasets) - iterator = dataset.make_initializable_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - sess.run(iterator.initializer) - for _ in range(100): - for i in range(10): - self.assertEqual(i, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def _normalize(self, vec): - return vec / vec.sum() - - def _chi2(self, expected, actual): - actual = np.asarray(actual) - expected = np.asarray(expected) - diff = actual - expected - chi2 = np.sum(diff * diff / expected, axis=0) - return chi2 - - def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples): - # Create a dataset that samples each integer in `[0, num_datasets)` - # with probability given by `weights[i]`. - dataset = interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(num_datasets) - ], weights) - dataset = dataset.take(num_samples) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - freqs = np.zeros([num_datasets]) - for _ in range(num_samples): - freqs[sess.run(next_element)] += 1 - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - return freqs - - def testSampleFromDatasets(self): - random_seed.set_random_seed(1619) - num_samples = 10000 - rand_probs = self._normalize(np.random.random_sample((15,))) - - # Use chi-squared test to assert that the observed distribution matches the - # expected distribution. Based on the implementation in - # "tensorflow/python/kernel_tests/multinomial_op_test.py". - for probs in [[.85, .05, .1], rand_probs]: - probs = np.asarray(probs) - classes = len(probs) - freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - # Also check that `weights` as a dataset samples correctly. - probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat() - freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples) - self.assertLess(self._chi2(probs, freqs / num_samples), 1e-3) - - def testErrors(self): - with self.assertRaisesRegexp(ValueError, - r"vector of length `len\(datasets\)`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[0.25, 0.25, 0.25, 0.25]) - - with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"): - interleave_ops.sample_from_datasets( - [dataset_ops.Dataset.range(10), - dataset_ops.Dataset.range(20)], - weights=[1, 1]) - - with self.assertRaisesRegexp(TypeError, "must have the same type"): - interleave_ops.sample_from_datasets([ - dataset_ops.Dataset.from_tensors(0), - dataset_ops.Dataset.from_tensors(0.0) - ]) - - -class SampleFromDatasetsSerializationTest( - dataset_serialization_test_base.DatasetSerializationTestBase): - - def _build_dataset(self, probs, num_samples): - dataset = interleave_ops.sample_from_datasets( - [ - dataset_ops.Dataset.from_tensors(i).repeat(None) - for i in range(len(probs)) - ], - probs, - seed=1813) - return dataset.take(num_samples) - - def testSerializationCore(self): - self.run_core_tests( - lambda: self._build_dataset([0.5, 0.5], 100), - lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100) - - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..30f1847dcddbfaf379ef2b09185f7a8db4aaeae2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import optimization +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.platform import test + + +class OptimizeDatasetTest(test.TestCase): + + def testDefaultOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize()) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testEmptyOptimizations(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize([])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + all([node.op != "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testOptimization(self): + dataset = dataset_ops.Dataset.range(10).map(lambda x: x * x).batch( + 10).apply(optimization.optimize(["map_and_batch_fusion"])) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + graph = graph_pb2.GraphDef().FromString( + sess.run(dataset._as_serialized_graph())) + self.assertTrue( + any([node.op == "MapAndBatchDatasetV2" for node in graph.node])) + self.assertAllEqual([x * x for x in range(10)], sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class OptimizeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + + def build_dataset(num_elements, batch_size): + return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch( + batch_size).apply(optimization.optimize(["map_and_batch_fusion"])) + + self.run_core_tests(lambda: build_dataset(200, 10), None, 20) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index e0237198b7d47eb98eeffe88d28bf9681b2722c6..3b07ef290bc38daa37472ef8919f3350851fe370 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -24,9 +24,8 @@ import zlib import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.example import example_pb2 -from tensorflow.core.example import feature_pb2 from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op @@ -280,163 +279,8 @@ def _interleave(iterators, cycle_length): num_open -= 1 -class ReadBatchFeaturesTest(test.TestCase): - - def setUp(self): - super(ReadBatchFeaturesTest, self).setUp() - self._num_files = 2 - self._num_records = 7 - self.test_filenames = self._createFiles() - - def _read_batch_features(self, - filenames, - num_epochs, - batch_size, - reader_num_threads=1, - parser_num_threads=1, - shuffle=False, - shuffle_seed=None, - drop_final_batch=False): - self.filenames = filenames - self.num_epochs = num_epochs - self.batch_size = batch_size - - return readers.make_batched_features_dataset( - file_pattern=self.filenames, - batch_size=self.batch_size, - features={ - "file": parsing_ops.FixedLenFeature([], dtypes.int64), - "record": parsing_ops.FixedLenFeature([], dtypes.int64), - "keywords": parsing_ops.VarLenFeature(dtypes.string) - }, - reader=core_readers.TFRecordDataset, - num_epochs=self.num_epochs, - shuffle=shuffle, - shuffle_seed=shuffle_seed, - reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads, - drop_final_batch=drop_final_batch).make_one_shot_iterator( - ).get_next() - - def _record(self, f, r): - example = example_pb2.Example( - features=feature_pb2.Features( - feature={ - "file": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[f])), - "record": - feature_pb2.Feature( - int64_list=feature_pb2.Int64List(value=[r])), - "keywords": - feature_pb2.Feature( - bytes_list=feature_pb2.BytesList( - value=self._get_keywords(f, r))) - })) - return example.SerializeToString() - - def _get_keywords(self, f, r): - num_keywords = 1 + (f + r) % 2 - keywords = [] - for index in range(num_keywords): - keywords.append(compat.as_bytes("keyword%d" % index)) - return keywords - - def _createFiles(self): - filenames = [] - for i in range(self._num_files): - fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) - filenames.append(fn) - writer = python_io.TFRecordWriter(fn) - for j in range(self._num_records): - writer.write(self._record(i, j)) - writer.close() - return filenames - - def _run_actual_batch(self, outputs, sess): - file_op = outputs["file"] - keywords_indices_op = outputs["keywords"].indices - keywords_values_op = outputs["keywords"].values - keywords_dense_shape_op = outputs["keywords"].dense_shape - record_op = outputs["record"] - return sess.run([ - file_op, keywords_indices_op, keywords_values_op, - keywords_dense_shape_op, record_op - ]) - - def _next_actual_batch(self, sess): - return self._run_actual_batch(self.outputs, sess) - - def _next_expected_batch(self, - file_indices, - batch_size, - num_epochs, - cycle_length=1): - - def _next_record(file_indices): - for j in file_indices: - for i in range(self._num_records): - yield j, i - - def _next_record_interleaved(file_indices, cycle_length): - return _interleave([_next_record([i]) for i in file_indices], - cycle_length) - - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - for _ in range(num_epochs): - if cycle_length == 1: - next_records = _next_record(file_indices) - else: - next_records = _next_record_interleaved(file_indices, cycle_length) - for record in next_records: - f = record[0] - r = record[1] - file_batch.append(f) - record_batch.append(r) - keywords = self._get_keywords(f, r) - keywords_batch_values.extend(keywords) - keywords_batch_indices.extend( - [[batch_index, i] for i in range(len(keywords))]) - batch_index += 1 - keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) - if len(file_batch) == batch_size: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [batch_size, keywords_batch_max_len], record_batch - ] - file_batch = [] - keywords_batch_indices = [] - keywords_batch_values = [] - keywords_batch_max_len = 0 - record_batch = [] - batch_index = 0 - if file_batch: - yield [ - file_batch, keywords_batch_indices, keywords_batch_values, - [len(file_batch), keywords_batch_max_len], record_batch - ] - - def _verify_records(self, - sess, - batch_size, - file_index=None, - num_epochs=1, - interleave_cycle_length=1): - if file_index is not None: - file_indices = [file_index] - else: - file_indices = range(self._num_files) - - for expected_batch in self._next_expected_batch( - file_indices, batch_size, num_epochs, interleave_cycle_length): - actual_batch = self._next_actual_batch(sess) - for i in range(len(expected_batch)): - self.assertAllEqual(expected_batch[i], actual_batch[i]) +class ReadBatchFeaturesTest( + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testRead(self): for batch_size in [1, 2]: @@ -444,33 +288,33 @@ class ReadBatchFeaturesTest(test.TestCase): with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 0, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 0, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from file 1. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[1], num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, 1, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, 1, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: # Basic test: read from both files. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, - batch_size=batch_size) - self._verify_records(sess, batch_size, num_epochs=num_epochs) + batch_size=batch_size).make_one_shot_iterator().get_next() + self.verify_records(sess, batch_size, num_epochs=num_epochs) with self.assertRaises(errors.OutOfRangeError): self._next_actual_batch(sess) @@ -504,18 +348,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with same seed produces the same result. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) + shuffle_seed=5).make_one_shot_iterator().get_next() for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) batch2 = self._run_actual_batch(outputs2, sess) @@ -525,18 +369,18 @@ class ReadBatchFeaturesTest(test.TestCase): # Test that shuffling with different seeds produces a different order. with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - outputs1 = self._read_batch_features( + outputs1 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=5) - outputs2 = self._read_batch_features( + shuffle_seed=5).make_one_shot_iterator().get_next() + outputs2 = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, shuffle=True, - shuffle_seed=15) + shuffle_seed=15).make_one_shot_iterator().get_next() all_equal = True for _ in range(total_records // batch_size): batch1 = self._run_actual_batch(outputs1, sess) @@ -552,13 +396,14 @@ class ReadBatchFeaturesTest(test.TestCase): for parser_num_threads in [2, 4]: with ops.Graph().as_default() as g: with self.test_session(graph=g) as sess: - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames, num_epochs=num_epochs, batch_size=batch_size, reader_num_threads=reader_num_threads, - parser_num_threads=parser_num_threads) - self._verify_records( + parser_num_threads=parser_num_threads).make_one_shot_iterator( + ).get_next() + self.verify_records( sess, batch_size, num_epochs=num_epochs, @@ -571,11 +416,11 @@ class ReadBatchFeaturesTest(test.TestCase): for num_epochs in [1, 10]: with ops.Graph().as_default(): # Basic test: read from file 0. - self.outputs = self._read_batch_features( + self.outputs = self.make_batch_feature( filenames=self.test_filenames[0], num_epochs=num_epochs, batch_size=batch_size, - drop_final_batch=True) + drop_final_batch=True).make_one_shot_iterator().get_next() for _, tensor in self.outputs.items(): if isinstance(tensor, ops.Tensor): # Guard against SparseTensor. self.assertEqual(tensor.shape[0], batch_size) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..805a7c7b7384d53cc166a48ba243502ef8643280 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -0,0 +1,218 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import readers +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import readers as core_readers +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import python_io +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class ReadBatchFeaturesTestBase(test.TestCase): + """Base class for setting up and testing `make_batched_feature_dataset`.""" + + def setUp(self): + super(ReadBatchFeaturesTestBase, self).setUp() + self._num_files = 2 + self._num_records = 7 + self.test_filenames = self._createFiles() + + def make_batch_feature(self, + filenames, + num_epochs, + batch_size, + reader_num_threads=1, + parser_num_threads=1, + shuffle=False, + shuffle_seed=None, + drop_final_batch=False): + self.filenames = filenames + self.num_epochs = num_epochs + self.batch_size = batch_size + + return readers.make_batched_features_dataset( + file_pattern=self.filenames, + batch_size=self.batch_size, + features={ + "file": parsing_ops.FixedLenFeature([], dtypes.int64), + "record": parsing_ops.FixedLenFeature([], dtypes.int64), + "keywords": parsing_ops.VarLenFeature(dtypes.string) + }, + reader=core_readers.TFRecordDataset, + num_epochs=self.num_epochs, + shuffle=shuffle, + shuffle_seed=shuffle_seed, + reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, + drop_final_batch=drop_final_batch) + + def _record(self, f, r): + example = example_pb2.Example( + features=feature_pb2.Features( + feature={ + "file": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[f])), + "record": + feature_pb2.Feature( + int64_list=feature_pb2.Int64List(value=[r])), + "keywords": + feature_pb2.Feature( + bytes_list=feature_pb2.BytesList( + value=self._get_keywords(f, r))) + })) + return example.SerializeToString() + + def _get_keywords(self, f, r): + num_keywords = 1 + (f + r) % 2 + keywords = [] + for index in range(num_keywords): + keywords.append(compat.as_bytes("keyword%d" % index)) + return keywords + + def _sum_keywords(self, num_files): + sum_keywords = 0 + for i in range(num_files): + for j in range(self._num_records): + sum_keywords += 1 + (i + j) % 2 + return sum_keywords + + def _createFiles(self): + filenames = [] + for i in range(self._num_files): + fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i) + filenames.append(fn) + writer = python_io.TFRecordWriter(fn) + for j in range(self._num_records): + writer.write(self._record(i, j)) + writer.close() + return filenames + + def _run_actual_batch(self, outputs, sess): + file_op = outputs["file"] + keywords_indices_op = outputs["keywords"].indices + keywords_values_op = outputs["keywords"].values + keywords_dense_shape_op = outputs["keywords"].dense_shape + record_op = outputs["record"] + return sess.run([ + file_op, keywords_indices_op, keywords_values_op, + keywords_dense_shape_op, record_op + ]) + + def _next_actual_batch(self, sess): + return self._run_actual_batch(self.outputs, sess) + + def _interleave(self, iterators, cycle_length): + pending_iterators = iterators + open_iterators = [] + num_open = 0 + for i in range(cycle_length): + if pending_iterators: + open_iterators.append(pending_iterators.pop(0)) + num_open += 1 + + while num_open: + for i in range(min(cycle_length, len(open_iterators))): + if open_iterators[i] is None: + continue + try: + yield next(open_iterators[i]) + except StopIteration: + if pending_iterators: + open_iterators[i] = pending_iterators.pop(0) + else: + open_iterators[i] = None + num_open -= 1 + + def _next_expected_batch(self, + file_indices, + batch_size, + num_epochs, + cycle_length=1): + + def _next_record(file_indices): + for j in file_indices: + for i in range(self._num_records): + yield j, i + + def _next_record_interleaved(file_indices, cycle_length): + return self._interleave([_next_record([i]) for i in file_indices], + cycle_length) + + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + for _ in range(num_epochs): + if cycle_length == 1: + next_records = _next_record(file_indices) + else: + next_records = _next_record_interleaved(file_indices, cycle_length) + for record in next_records: + f = record[0] + r = record[1] + file_batch.append(f) + record_batch.append(r) + keywords = self._get_keywords(f, r) + keywords_batch_values.extend(keywords) + keywords_batch_indices.extend( + [[batch_index, i] for i in range(len(keywords))]) + batch_index += 1 + keywords_batch_max_len = max(keywords_batch_max_len, len(keywords)) + if len(file_batch) == batch_size: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [batch_size, keywords_batch_max_len], record_batch + ] + file_batch = [] + keywords_batch_indices = [] + keywords_batch_values = [] + keywords_batch_max_len = 0 + record_batch = [] + batch_index = 0 + if file_batch: + yield [ + file_batch, keywords_batch_indices, keywords_batch_values, + [len(file_batch), keywords_batch_max_len], record_batch + ] + + def verify_records(self, + sess, + batch_size, + file_index=None, + num_epochs=1, + interleave_cycle_length=1): + if file_index is not None: + file_indices = [file_index] + else: + file_indices = range(self._num_files) + + for expected_batch in self._next_expected_batch( + file_indices, batch_size, num_epochs, interleave_cycle_length): + actual_batch = self._next_actual_batch(sess) + for i in range(len(expected_batch)): + self.assertAllEqual(expected_batch[i], actual_batch[i]) diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index bcc644c0971854d948025009dc7add2fea214048..1b67a33f04b0c2ac80402e163005123a4b3e4400 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -20,11 +20,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class ShuffleDatasetSerializationTest( @@ -50,26 +52,100 @@ class ShuffleDatasetSerializationTest( num_repeats = 5 num_outputs = range_limit * num_repeats buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False # pylint: disable=cell-var-from-loop # pylint: disable=g-long-lambda - for buffer_size in buffer_sizes: - self.run_core_tests( - lambda: self._build_shuffle_dataset( + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + def testNonDeterministicSeeding(self): + + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration), - lambda: self._build_shuffle_dataset( + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + # We checkpoint the initial state of the Dataset so that we can restore + # the seeds in the next run. Since the seeding is non-deterministic + # the dataset gets initialized with different seeds each time. + expected = self.gen_outputs( + ds_fn, + break_points=[0], + num_outputs=num_outputs, + ckpt_saved=False, + verify_exhausted=False, + save_checkpoint_at_end=False) + actual = self.gen_outputs( + ds_fn, + break_points=self.gen_break_points(num_outputs), + num_outputs=num_outputs, + ckpt_saved=True, + verify_exhausted=False) + self.match(expected, actual) + + def testMultipleIterators(self): + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + + for reshuffle_each_iteration in [True, False]: + for buffer_size in buffer_sizes: + + def ds_fn(): + # pylint: disable=cell-var-from-loop + return self._build_shuffle_dataset( range_limit=range_limit, num_repeats=num_repeats, buffer_size=buffer_size, - seed=10, - reshuffle_each_iteration=reshuffle_each_iteration), - num_outputs) - # pylint: enable=cell-var-from-loop - # pylint: enable=g-long-lambda + seed=None, # Iterator seeds are generated non-deterministically. + reshuffle_each_iteration=reshuffle_each_iteration) + # pylint: enable=cell-var-from-loop + + with ops.Graph().as_default() as g: + ds = ds_fn() + iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()] + get_next_ops = [it.get_next() for it in iterators] + saveables = [ + contrib_iterator_ops.make_saveable_from_iterator(it) + for it in iterators + ] + for saveable in saveables: + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + self._save(sess, saver) + expected = [sess.run(get_next_ops) for _ in range(num_outputs)] + self._restore(saver, sess) + actual = [sess.run(get_next_ops) for _ in range(num_outputs)] + self.match(expected, actual) class ShuffleAndRepeatTest( diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 5c74ed6ae7210e8e22efb6e8fdb773397459ce1e..17b6644759e53f84b23e070a71267aa15dcffe49 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTest(test.TestCase): +class StatsDatasetTestBase(test.TestCase): def _assertSummaryHasCount(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() @@ -49,6 +50,9 @@ class StatsDatasetTest(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + +class StatsDatasetTest(StatsDatasetTestBase): + def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).map( @@ -193,6 +197,45 @@ class StatsDatasetTest(test.TestCase): self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) +class FeatureStatsDatasetTest( + StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): + + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=True).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(total_records // batch_size): + sess.run(next_element) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats:feature-values", total_records) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:features", total_records * 3) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats:feature-values", + self._sum_keywords(1) * num_epochs + 2 * total_records) + + class StatsDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index eceecfd1744d0ae28953a4504450653efa473569..33b7a75046cf2acfa3d787833b907aa2b28dbdca 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -96,8 +96,10 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", + ":gen_dataset_ops", ":interleave_ops", ":shuffle_ops", + ":stats_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", @@ -106,12 +108,12 @@ py_library( "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -142,6 +144,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", ], @@ -208,6 +211,20 @@ py_library( ], ) +py_library( + name = "optimization", + srcs = ["optimization.py"], + srcs_version = "PY2AND3", + deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "resampling", srcs = ["resampling.py"], @@ -368,6 +385,7 @@ py_library( ":get_single_element", ":grouping", ":interleave_ops", + ":optimization", ":prefetching_ops", ":readers", ":resampling", diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index b9393de4e90ae2597045b29070934b94e18cfcbd..052618e08c8f204613db5a20d42e078f17f12840 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.framework import with_shape from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -29,6 +30,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +from tensorflow.python.util import deprecation def dense_to_sparse_batch(batch_size, row_shape): @@ -101,10 +103,7 @@ class UnbatchDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unbatch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -218,6 +217,8 @@ def filter_irregular_batches(batch_size): return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.batch(..., drop_remainder=True)`.") def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). @@ -250,12 +251,16 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `batch(..., drop_remainder=True)` any time + # after 6/30/2018. batched = dataset.batch(batch_size) return filter_irregular_batches(batch_size)(batched) return _apply_fn +@deprecation.deprecated( + None, "Use `tf.data.Dataset.padded_batch(..., drop_remainder=True)`.") def padded_batch_and_drop_remainder(batch_size, padded_shapes, padding_values=None): @@ -284,6 +289,8 @@ def padded_batch_and_drop_remainder(batch_size, def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" + # TODO(jsimsa): Switch to using `padded_batch(..., drop_remainder=True)` + # any time after 6/30/2018. batched = dataset.padded_batch( batch_size, padded_shapes=padded_shapes, padding_values=padding_values) return filter_irregular_batches(batch_size)(batched) @@ -309,11 +316,8 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + row_shape=convert.partial_shape_to_tensor(self._row_shape), + **dataset_ops.flat_structure(self)) @property def output_classes(self): @@ -490,10 +494,7 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): batch_size=self._batch_size_t, num_parallel_calls=self._num_parallel_calls_t, drop_remainder=self._drop_remainder_t, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 6c21e489f7c35484ebacd465e3b46d6920df5933..5f5513849cb29a18b86ba8bcee1ab6c9c60674cb 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse def ignore_errors(): @@ -64,10 +62,7 @@ class IgnoreErrorsDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py index 3a07df572748e464284f580d67e3a664e71acdfe..0f4cd8e20c5727a5bcfa1dce4dadbfa8f90bd551 100644 --- a/tensorflow/contrib/data/python/ops/get_single_element.py +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -64,10 +64,7 @@ def get_single_element(dataset): nested_ret = nest.pack_sequence_as( dataset.output_types, gen_dataset_ops.dataset_to_single_element( dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(sparse.as_dense_types( - dataset.output_types, dataset.output_classes)), - output_shapes=nest.flatten(sparse.as_dense_shapes( - dataset.output_shapes, dataset.output_classes)))) + **dataset_ops.flat_structure(dataset))) return sparse.deserialize_sparse_tensors( nested_ret, dataset.output_types, dataset.output_shapes, dataset.output_classes) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ea229b5b27b117984e508fa4edc6f1cf713008b4..f9f25e6a0687fe7167525847c64743f52a551fb0 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -300,6 +300,7 @@ class GroupByReducerDataset(dataset_ops.Dataset): raise ValueError( "`key_func` must return a single tf.int64 tensor. " "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access return ret self._key_func = tf_key_func @@ -327,6 +328,8 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._state_types = nest.pack_sequence_as( ret, [t.dtype for t in nest.flatten(ret)]) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + # Serialize any sparse tensors. ret = nest.pack_sequence_as( ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) @@ -398,6 +401,8 @@ class GroupByReducerDataset(dataset_ops.Dataset): nest.pack_sequence_as(self._state_types, [t.dtype for t in flat_new_state]))) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + # Serialize any sparse tensors. ret = nest.pack_sequence_as( ret, @@ -464,6 +469,8 @@ class GroupByReducerDataset(dataset_ops.Dataset): self._output_types = nest.pack_sequence_as( ret, [t.dtype for t in nest.flatten(ret)]) + dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()") # pylint: disable=protected-access + # Serialize any sparse tensors. ret = nest.pack_sequence_as( ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) @@ -495,10 +502,7 @@ class GroupByReducerDataset(dataset_ops.Dataset): init_func=self._init_func, reduce_func=self._reduce_func, finalize_func=self._finalize_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) class GroupByWindowDataset(dataset_ops.Dataset): @@ -525,6 +529,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): if window_size.dtype != dtypes.int64: raise ValueError( "`window_size_func` must return a single tf.int64 tensor.") + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return window_size self._window_size_func = tf_window_size_func @@ -557,6 +562,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) if ret.dtype != dtypes.int64: raise ValueError("`key_func` must return a single tf.int64 tensor.") + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return ret self._key_func = tf_key_func @@ -580,6 +586,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._output_classes = output_dataset.output_classes self._output_types = output_dataset.output_types self._output_shapes = output_dataset.output_shapes + dataset_ops._warn_if_collections("tf.contrib.data.group_by_window()") # pylint: disable=protected-access return output_dataset._as_variant_tensor() # pylint: disable=protected-access self._reduce_func = tf_reduce_func @@ -606,10 +613,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) class Reducer(object): diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 812a50ecbf105393f7e422edbbdf5c87311d72c1..70153ac575758f16beff373941dfefb32bd342cf 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -24,9 +24,9 @@ from tensorflow.contrib.data.python.ops import random_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation @@ -170,10 +170,7 @@ class DirectedInterleaveDataset(dataset_ops.Dataset): return gen_dataset_ops.directed_interleave_dataset( self._selector_input._as_variant_tensor(), [data_input._as_variant_tensor() for data_input in self._data_inputs], - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property @@ -240,3 +237,47 @@ def sample_from_datasets(datasets, weights=None, seed=None): (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset) return DirectedInterleaveDataset(selector_input, datasets) + + +def choose_from_datasets(datasets, choice_dataset): + """Creates a dataset that deterministically chooses elements from `datasets`. + + For example, given the following datasets: + + ```python + datasets = [tf.data.Dataset.from_tensors("foo").repeat(), + tf.data.Dataset.from_tensors("bar").repeat(), + tf.data.Dataset.from_tensors("baz").repeat()] + + # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`. + choice_dataset = tf.data.Dataset.range(3).repeat(3) + + result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset) + ``` + + The elements of `result` will be: + + ``` + "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz" + ``` + + Args: + datasets: A list of @{tf.data.Dataset} objects with compatible structure. + choice_dataset: A @{tf.data.Dataset} of scalar `tf.int64` tensors between + `0` and `len(datasets) - 1`. + + Returns: + A dataset that interleaves elements from `datasets` according to the values + of `choice_dataset`. + + Raises: + TypeError: If the `datasets` or `choice_dataset` arguments have the wrong + type. + """ + if not (choice_dataset.output_types == dtypes.int64 + and choice_dataset.output_shapes.is_compatible_with( + tensor_shape.scalar()) + and choice_dataset.output_classes == ops.Tensor): + raise TypeError("`choice_dataset` must be a dataset of scalar " + "`tf.int64` tensors.") + return DirectedInterleaveDataset(choice_dataset, datasets) diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..9612ac5ae910f8ee08d4b3ed9097a5c80266fcfd --- /dev/null +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -0,0 +1,74 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for optimizing `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def optimize(optimizations=None): + """A transformation that applies optimizations. + + Args: + optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying + optimizations to use. If not specified, the default set of optimizations + is applied. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return OptimizeDataset(dataset, optimizations) + + return _apply_fn + + +class OptimizeDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and applies optimizations.""" + + def __init__(self, input_dataset, optimizations): + """See `optimize()` for details.""" + super(OptimizeDataset, self).__init__() + self._input_dataset = input_dataset + if optimizations is None: + optimizations = [] + self._optimizations = ops.convert_to_tensor( + optimizations, dtype=dtypes.string, name="optimizations") + + def _as_variant_tensor(self): + return gen_dataset_ops.optimize_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._optimizations, + **dataset_ops.flat_structure(self)) + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py index 28ef5e50f39dd7d1b6f124e58e068fc968ddd6dc..e670c4c8354f4067eb21c9b1fce708147c162967 100644 --- a/tensorflow/contrib/data/python/ops/random_ops.py +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -39,10 +37,7 @@ class RandomDataset(dataset_ops.Dataset): return gen_dataset_ops.random_dataset( seed=self._seed, seed2=self._seed2, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 75c31a944a09462f534f6ae3e3204c812ecf28d9..83095c7ba1c6465d18490e5197f71bf7f1fe2497 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import csv import numpy as np @@ -25,6 +26,7 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import convert @@ -467,11 +469,11 @@ def make_csv_dataset( Args: *columns: list of `Tensor`s corresponding to one csv record. Returns: - A dictionary of feature names to values for that particular record. If + An OrderedDict of feature names to values for that particular record. If label_name is provided, extracts the label feature to be returned as the second element of the tuple. """ - features = dict(zip(column_names, columns)) + features = collections.OrderedDict(zip(column_names, columns)) if label_name is not None: label = features.pop(label_name) return features, label @@ -753,6 +755,8 @@ def make_batched_features_dataset(file_pattern, dataset = _maybe_shuffle_and_repeat( dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed) + dataset = dataset.apply(stats_ops.feature_stats("record_stats")) + if drop_final_batch: dataset = dataset.apply(batching.batch_and_drop_remainder(batch_size)) else: diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py index bad6edd5147d832228c412919f1e6e782aafc40f..182a5c6ff36fcda8c9e2c522cce07bed0c2daec9 100644 --- a/tensorflow/contrib/data/python/ops/resampling.py +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -291,4 +291,4 @@ def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs): # TODO(joelshor): Simplify fraction, if possible. a_i = (ratio_l - m) / (max_ratio - m) - return a_i, m \ No newline at end of file + return a_i, m diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index e911ad0fa0541f2d8b991d66182dd002c2ecaab0..67eede981cb8f685ba4840d1f3c12bfea54c7646 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -148,6 +148,8 @@ class _ScanDataset(dataset_ops.Dataset): self._output_types = nest.pack_sequence_as( output_value, [t.dtype for t in nest.flatten(output_value)]) + dataset_ops._warn_if_collections("tf.contrib.data.scan()") # pylint: disable=protected-access + # Serialize any sparse tensors. new_state = nest.pack_sequence_as(new_state, [ t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state)) @@ -193,10 +195,7 @@ class _ScanDataset(dataset_ops.Dataset): nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py index f35795abd38000b13cec0f08596e2ff66e86286c..d7f8a73fe3d67bb83e44e962832ce34c116aef66 100644 --- a/tensorflow/contrib/data/python/ops/shuffle_ops.py +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -18,9 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest from tensorflow.python.data.util import random_seed -from tensorflow.python.data.util import sparse from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -56,10 +54,7 @@ class _ShuffleAndRepeatDataset(dataset_ops.Dataset): count=self._count, seed=self._seed, seed2=self._seed2, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/sliding.py b/tensorflow/contrib/data/python/ops/sliding.py index 19cc3cb89fc5c494f79ce1d25ed57c92099c8bd2..f935beb1a9e85d4901857e7781a5ed8473838fa5 100644 --- a/tensorflow/contrib/data/python/ops/sliding.py +++ b/tensorflow/contrib/data/python/ops/sliding.py @@ -19,7 +19,6 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -43,10 +42,7 @@ class _SlideDataset(dataset_ops.Dataset): self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access window_size=self._window_size, stride=self._stride, - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index 3cbaab5affd7397213b0fbb6b0682db92b99d591..3c82a03df1745d855b2d3f918f7bbde113600556 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -18,8 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -97,10 +95,7 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset): return gen_dataset_ops.set_stats_aggregator_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._stats_aggregator._resource, # pylint: disable=protected-access - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): @@ -176,6 +171,27 @@ def latency_stats(tag): return _apply_fn +def feature_stats(tag): + """Records the features stats from `Example` records of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with the output + dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will be + associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag) + + return _apply_fn + + class _StatsDataset(dataset_ops.Dataset): """A `Dataset` that acts as an identity, and also records statistics.""" @@ -189,10 +205,7 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index 56f67e1766bbaff680bdff6b939df0c3ba68c679..bb49604d4de90d726418684124608438aa33e6cf 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -22,8 +22,6 @@ import threading from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.ops import resource_variable_ops @@ -69,10 +67,7 @@ class _ThreadPoolDataset(dataset_ops.Dataset): return gen_dataset_ops.thread_pool_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._thread_pool._resource, # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 765ef3f9b6d42c9d7af3ce4916731d37d65c9260..4ce6ddede8350735636fd152fdc9df0319265990 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -65,10 +63,7 @@ class UniqueDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.unique_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten( - sparse.as_dense_shapes(self.output_shapes, self.output_classes)), - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes))) + **dataset_ops.flat_structure(self)) @property def output_classes(self): diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 64a77bbed1d55c3d95329d9c7783c2b468bde745..9dfb8552f1b0f058b44f8ed09c2ed681367293d5 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -77,6 +77,7 @@ py_library( "//tensorflow/python:device_util", "//tensorflow/python:distribute", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -148,6 +149,7 @@ py_library( ], deps = [ ":mirrored_strategy", + ":multi_worker_strategy", ":one_device_strategy", ":tpu_strategy", "//tensorflow/contrib/optimizer_v2:training", @@ -445,10 +447,31 @@ py_library( srcs = ["cross_tower_utils.py"], srcs_version = "PY2AND3", deps = [ + ":values", + "//tensorflow/contrib/all_reduce:all_reduce_py", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +cuda_py_test( + name = "cross_tower_utils_test", + srcs = ["cross_tower_utils_test.py"], + additional_deps = [ + ":combinations", + ":cross_tower_utils", + "@absl_py//absl/testing:parameterized", + "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], + tags = [ + "no_pip", ], ) @@ -476,6 +499,7 @@ cuda_py_test( additional_deps = [ ":combinations", ":cross_tower_ops", + ":multi_worker_test_base", ":values", "@absl_py//absl/testing:parameterized", "//tensorflow/python:array_ops", @@ -485,6 +509,7 @@ cuda_py_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", ], + shard_count = 15, tags = [ "multi_and_single_gpu", "no_pip", @@ -547,3 +572,41 @@ cuda_py_test( "no_pip", ], ) + +cuda_py_test( + name = "keras_test", + srcs = ["keras_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/contrib/distribute/python:mirrored_strategy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python/estimator:keras", + "//tensorflow/python/estimator:run_config", + "//tensorflow/python/keras", + ], + tags = [ + "multi_and_single_gpu", + "noguitar", + "notsan", + ], +) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 15935817b0283ebc04b95304afe41d8690a11442..ba03b14deb9a3897dae29382ce601c0319f84735 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -41,11 +41,15 @@ from __future__ import print_function from collections import OrderedDict import sys +import types +import unittest from absl.testing import parameterized +import six -from tensorflow.contrib.distribute.python import mirrored_strategy -from tensorflow.contrib.distribute.python import one_device_strategy -from tensorflow.contrib.distribute.python import tpu_strategy +from tensorflow.contrib.distribute.python import mirrored_strategy as mirrored_lib +from tensorflow.contrib.distribute.python import multi_worker_strategy +from tensorflow.contrib.distribute.python import one_device_strategy as one_device_lib +from tensorflow.contrib.distribute.python import tpu_strategy as tpu_lib from tensorflow.contrib.optimizer_v2 import adam as adam_v2 from tensorflow.contrib.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.eager import context @@ -67,8 +71,8 @@ def generate(combinations): combinations: a list of dictionaries created using combine() and times(). Restrictions: - -- there should always be a "mode" argument. Accepted values are "eager" - and "graph". + -- the "mode" argument can be either "eager" or "graph". It's "graph" by + default. -- arguments of the test method must match by name to get the corresponding value of the combination. Tests must accept all arguments except the "mode", "required_tpu" and "required_gpus". @@ -83,14 +87,15 @@ def generate(combinations): test will be skipped if the specified number of GPUs aren't available. Returns: - a decorator that will cause the test method to be run under the specified - conditions. + a decorator that will cause the test method or the test class to be run + under the specified conditions. Raises: - ValueError - if "mode" argument wasn't either "eager" or "graph". + ValueError - if "mode" argument wasn't either "eager" or "graph" or if other + arguments were not accepted by the test method. """ - def decorator(test_function): + def decorator(test_method_or_class): """The decorator to be returned.""" # Generate good test names that can be used with --test_filter. @@ -110,70 +115,91 @@ def generate(combinations): list(combination.items()) + [("testcase_name", "_test{}".format(name))])) - @parameterized.named_parameters(*named_combinations) - def decorated(self, **kwargs): - """A wrapped test method that sets up `test_function`.""" - assert "mode" in kwargs - mode = kwargs["mode"] - - distribution = kwargs.pop("distribution", None) - required_tpu = kwargs.pop("required_tpu", False) - required_gpus = kwargs.pop("required_gpus", None) - - if distribution: - assert required_gpus is None, ( - "Do not use `required_gpus` and `distribution` together.") - assert required_tpu is False, ( - "Do not use `required_tpu` and `distribution` together.") - kwargs["distribution"] = distribution.strategy - required_gpus = distribution.required_gpus - required_tpu = distribution.required_tpu - - if required_tpu and not TPU_TEST: - self.skipTest("Test requires a TPU, but it's not available.") - if not required_tpu and TPU_TEST: - self.skipTest("Test that doesn't require a TPU.") - - if not required_gpus: - if GPU_TEST: - self.skipTest("Test that doesn't require GPUs.") - elif context.num_gpus() < required_gpus: - self.skipTest( - "{} GPUs are not available for this test. {} GPUs are available". - format(required_gpus, context.num_gpus())) - - # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` - # that the user might have specified. `kwargs` still has `mode`, which - # the test is allowed to accept or ignore. - requested_arguments = tf_inspect.getfullargspec(test_function).args - missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( - set(requested_arguments + ["mode"])) - if missing_arguments: - raise ValueError("The test is missing arguments {} .".format( - missing_arguments)) - - kwargs_to_pass = {} - for arg in requested_arguments: - if arg == "self": - kwargs_to_pass[arg] = self - else: - kwargs_to_pass[arg] = kwargs[arg] - - if mode == "eager": - with context.eager_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - elif mode == "graph": - with context.graph_mode(), ops.Graph().as_default(): - test_function(**kwargs_to_pass) - else: - raise ValueError( - "'mode' has to be either 'eager' or 'graph' and not {}".format( - mode)) + if isinstance(test_method_or_class, type): + class_object = test_method_or_class + class_object._test_method_ids = test_method_ids = {} + for name, test_method in six.iteritems(class_object.__dict__.copy()): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + isinstance(test_method, types.FunctionType)): + delattr(class_object, name) + methods = {} + parameterized._update_class_dict_for_param_test_case( + class_object.__name__, methods, test_method_ids, name, + parameterized._ParameterizedTestIter( + _augment_with_special_arguments(test_method), + named_combinations, parameterized._NAMED, name)) + for method_name, method in six.iteritems(methods): + setattr(class_object, method_name, method) + + return class_object + else: + test_method = _augment_with_special_arguments(test_method_or_class) + return parameterized.named_parameters(*named_combinations)(test_method) - return decorated return decorator +def _augment_with_special_arguments(test_method): + def decorated(self, **kwargs): + """A wrapped test method that treats some arguments in a special way.""" + mode = kwargs.pop("mode", "graph") + + distribution = kwargs.pop("distribution", None) + required_tpu = kwargs.pop("required_tpu", False) + required_gpus = kwargs.pop("required_gpus", None) + + if distribution: + assert required_gpus is None, ( + "Do not use `required_gpus` and `distribution` together.") + assert required_tpu is False, ( + "Do not use `required_tpu` and `distribution` together.") + kwargs["distribution"] = distribution.strategy + required_gpus = distribution.required_gpus + required_tpu = distribution.required_tpu + + if required_tpu and not TPU_TEST: + self.skipTest("Test requires a TPU, but it's not available.") + if not required_tpu and TPU_TEST: + self.skipTest("Test that doesn't require a TPU.") + + if not required_gpus: + if GPU_TEST: + self.skipTest("Test that doesn't require GPUs.") + elif context.num_gpus() < required_gpus: + self.skipTest( + "{} GPUs are not available for this test. {} GPUs are available". + format(required_gpus, context.num_gpus())) + + # At this point, `kwargs` doesn't have `required_gpus` or `required_tpu` + # that the user might have specified. `kwargs` still has `mode`, which + # the test is allowed to accept or ignore. + requested_arguments = tf_inspect.getfullargspec(test_method).args + missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( + set(requested_arguments + ["mode"])) + if missing_arguments: + raise ValueError("The test is missing arguments {} .".format( + missing_arguments)) + + kwargs_to_pass = {} + for arg in requested_arguments: + if arg == "self": + kwargs_to_pass[arg] = self + else: + kwargs_to_pass[arg] = kwargs[arg] + + if mode == "eager": + with ops.Graph().as_default(), context.eager_mode(): + test_method(**kwargs_to_pass) + elif mode == "graph": + with ops.Graph().as_default(), context.graph_mode(): + test_method(**kwargs_to_pass) + else: + raise ValueError( + "'mode' has to be either 'eager' or 'graph' and not {}".format( + mode)) + return decorated + + def combine(**kwargs): """Generate combinations based on its keyword arguments. @@ -264,9 +290,9 @@ class NamedObject(object): class NamedDistribution(object): """Translates DistributionStrategy and its data into a good name.""" - def __init__(self, name, distribution, required_gpus=None, + def __init__(self, name, distribution_fn, required_gpus=None, required_tpu=False): - self._distribution = distribution + self._distribution_fn = distribution_fn self._name = name self._required_gpus = required_gpus self._required_tpu = required_tpu @@ -276,7 +302,7 @@ class NamedDistribution(object): @property def strategy(self): - return self._distribution + return self._distribution_fn() @property def required_gpus(self): @@ -287,32 +313,60 @@ class NamedDistribution(object): return self._required_tpu +# pylint: disable=g-long-lambda default_strategy = NamedDistribution( "Default", - distribute_lib._default_distribution_strategy, # pylint: disable=protected-access + lambda: distribute_lib._default_distribution_strategy, # pylint: disable=protected-access required_gpus=None) one_device_strategy = NamedDistribution( - "OneDeviceCPU", one_device_strategy.OneDeviceStrategy("/cpu:0"), + "OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"), required_gpus=None) tpu_strategy_single_iteration = NamedDistribution( "TPUSingleIteration", - tpu_strategy.TPUStrategy(iterations_per_step=1), + lambda: tpu_lib.TPUStrategy(iterations_per_step=1), required_tpu=True) -tpu_strategy = NamedDistribution( - "TPU", tpu_strategy.TPUStrategy(), required_tpu=True) +tpu_strategy = NamedDistribution("TPU", tpu_lib.TPUStrategy, required_tpu=True) # Note that we disable prefetching for testing since prefetching makes # the input non-deterministic. mirrored_strategy_with_gpu_and_cpu = NamedDistribution( "MirroredCPUAndGPU", - mirrored_strategy.MirroredStrategy( + lambda: mirrored_lib.MirroredStrategy( ["/gpu:0", "/cpu:0"], prefetch_on_device=False), required_gpus=1) mirrored_strategy_with_two_gpus = NamedDistribution( "Mirrored2GPUs", - mirrored_strategy.MirroredStrategy( + lambda: mirrored_lib.MirroredStrategy( ["/gpu:0", "/gpu:1"], prefetch_on_device=False), required_gpus=2) +multi_worker_strategy_with_cpu = NamedDistribution( + "MultiWorkerCPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=0), 0) +multi_worker_strategy_with_one_gpu = NamedDistribution( + "MultiWorker1GPU", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=1), 1) +multi_worker_strategy_with_two_gpus = NamedDistribution( + "MultiWorker2GPUs", + lambda: multi_worker_strategy.MultiWorkerMirroredStrategy( + cluster={ + "worker": [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + }, + num_gpus_per_worker=2), 2) + adam_optimizer_v1_fn = NamedObject( "AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1)) gradient_descent_optimizer_v1_fn = NamedObject( diff --git a/tensorflow/contrib/distribute/python/combinations_test.py b/tensorflow/contrib/distribute/python/combinations_test.py index 184bcf27e59d68b82d28e8f01890c04f214c017c..86aa48cea889c6c2ce169b18bcabb6d08890fbed 100644 --- a/tensorflow/contrib/distribute/python/combinations_test.py +++ b/tensorflow/contrib/distribute/python/combinations_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from collections import OrderedDict +from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.python.eager import test @@ -120,5 +121,28 @@ class TestingCombinationsTest(test.TestCase): _ = combinations.times(c1, c2) +@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1])) +class CombineTheTestSuite(parameterized.TestCase): + + def test_add_things(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def test_add_things_one_more(self, a, b, c): + self.assertLessEqual(3, a + b + c) + self.assertLessEqual(a + b + c, 5) + + def not_a_test(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + def _test_but_private(self, a=0, b=0, c=0): + del a, b, c + self.fail() + + # Check that nothing funny happens to a non-callable that starts with "_test". + test_member = 0 + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py index c6a1bf6a9f65828c45617ae18a1b0989f9d46225..f8ae8b9712c392fa948c8598dd123cdea01d9866 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import six from tensorflow.contrib.distribute.python import cross_tower_utils @@ -77,12 +78,12 @@ def _all_devices_match(value_destination_pairs): return True -def _simple_broadcast(tensor, destinations): +def _simple_broadcast(value, destinations): index = {} devices = _get_devices_from(destinations) for d in devices: - with ops.device(d): - index[d] = array_ops.identity(tensor) + index[d] = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + value, d) return value_lib.Mirrored(index) @@ -98,7 +99,9 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, continue count += len(v_list) # Sum within each device before aggregating across devices. - v = math_ops.add_n(v_list) + # TODO(yuefengz): Check whether it helps to use accumulation_fn here. + v = cross_tower_utils.aggregate_tensors_or_indexed_slices( + v_list, math_ops.add_n) else: count += 1 all_values.append(v) @@ -107,11 +110,12 @@ def _simple_reduce(per_device_value, reduce_to_device, accumulation_fn, with ops.device(reduce_to_device): with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): - if method_string == "sum": - reduced = accumulation_fn(all_values) - elif method_string == "mean": - reduced = accumulation_fn(all_values) / count - else: + reduced = cross_tower_utils.aggregate_tensors_or_indexed_slices( + all_values, accumulation_fn) + if method_string == "mean": + reduced = cross_tower_utils.divide_by_n_tensors_or_indexed_slices( + reduced, count) + elif method_string != "sum": raise ValueError("`method_string` must be 'sum' or 'mean'") return reduced @@ -231,7 +235,13 @@ class ReductionToOneDeviceCrossTowerOps(CrossTowerOps): def _group_value_by_device(per_device_values): """Group values into sublists by their devices. - This grouping is needed to call the all-reduce library. + This grouping is needed to call the all-reduce library because it expects a + list of the following form: + [(grad0_gpu0, v0_gpu0), (grad1_gpu0, v1_gpu0), (grad2_gpu0, v2_gpu0) ... + (grad0_gpu1, v0_gpu1), (grad1_gpu1, v1_gpu1), (grad2_gpu1, v2_gpu1) ... + (grad0_gpu2, v0_gpu2), (grad1_gpu0, v1_gpu2), (grad2_gpu0, v2_gpu2) ... + ... + ] Args: per_device_values: a list of PerDevice obejcts. @@ -319,7 +329,17 @@ class ConcatAndSplitPacker(object): # TODO(zhengxq): it is also possible to optimize away all the concat # as well. num_splits = self.num_packs - total_grad_size = array_ops.size(concat_grads) + + # The array_ops.size function will sometimes remove static shapes. So if + # all gradient shapes are defined, we use another method to get the + # total size. + # TODO(yuefengz): move this logic to array_ops.size. + if all([g.shape.is_fully_defined() for g, _ in tower_grads_and_vars]): + total_grad_size = sum( + [g.shape.num_elements() for g, _ in tower_grads_and_vars]) + else: + total_grad_size = array_ops.size(concat_grads) + split_size = total_grad_size // num_splits split_size_last = total_grad_size - split_size * (num_splits - 1) split_sizes = [split_size] * (num_splits - 1) + [split_size_last] @@ -409,6 +429,31 @@ class AggregateSmallTensorPacker(object): self.packing) +def _pack_tensors(device_grads, + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=0): + """Pack tensors if specified.""" + if num_packs > 0: + tensor_packer = ConcatAndSplitPacker(num_packs) + device_grad_packs = tensor_packer.pack(device_grads) + elif agg_small_grads_max_bytes > 0 and agg_small_grads_max_group > 0: + tensor_packer = AggregateSmallTensorPacker(agg_small_grads_max_bytes, + agg_small_grads_max_group) + device_grad_packs = tensor_packer.pack(device_grads) + else: + tensor_packer = None + device_grad_packs = device_grads + return device_grad_packs, tensor_packer + + +def _unpack_tensors(reduced, tensor_packer=None): + """Unpack tensors if they are packed before all-reduce.""" + if tensor_packer: + return tensor_packer.unpack(reduced) + return reduced + + class AllReduceCrossTowerOps(CrossTowerOps): """Reduction using all reduce.""" @@ -437,17 +482,25 @@ class AllReduceCrossTowerOps(CrossTowerOps): agg_small_grads_max_group: see above. tensors. """ - self.all_reduce_alg = all_reduce_alg - self.num_packs = num_packs - self.agg_small_grads_max_bytes = agg_small_grads_max_bytes - self.agg_small_grads_max_group = agg_small_grads_max_group + self._all_reduce_alg = all_reduce_alg + self._num_packs = num_packs + self._agg_small_grads_max_bytes = agg_small_grads_max_bytes + self._agg_small_grads_max_group = agg_small_grads_max_group super(AllReduceCrossTowerOps, self).__init__() def _reduce(self, method_string, per_device_value, destinations): + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + per_device_value) if ((destinations is None or _devices_match(per_device_value, destinations)) - and not context.executing_eagerly()): + and not context.executing_eagerly() + and not contains_indexed_slices): return self._batch_all_reduce(method_string, [per_device_value])[0] else: + if contains_indexed_slices: + logging.log_first_n( + logging.WARN, + "Efficient allreduce is not supported for IndexedSlices.", 10) + devices = _get_devices_from(destinations or per_device_value) reduce_to_device = devices[0] reduced = _simple_reduce(per_device_value, reduce_to_device, @@ -455,14 +508,18 @@ class AllReduceCrossTowerOps(CrossTowerOps): return self.broadcast(reduced, devices) def _batch_reduce(self, method_string, value_destination_pairs): - if (_all_devices_match(value_destination_pairs) and - not context.executing_eagerly()): + all_devices_match = _all_devices_match(value_destination_pairs) + contains_indexed_slices = cross_tower_utils.contains_indexed_slices( + value_destination_pairs) + if (all_devices_match and not context.executing_eagerly() + and not contains_indexed_slices): return self._batch_all_reduce(method_string, [v[0] for v in value_destination_pairs]) else: - if not context.executing_eagerly(): + if not all_devices_match: logging.warning("Efficient batch_reduce is not supported if " "destinations are different.") + return [ self._reduce(method_string, t, destinations=v) for t, v in value_destination_pairs @@ -470,37 +527,24 @@ class AllReduceCrossTowerOps(CrossTowerOps): def _batch_all_reduce(self, method_string, per_device_values): """All reduce algorithm in a batch.""" + logging.info( + "batch_all_reduce invoked for batches size = %d with " + "algorithm = %s, num_packs = %d, agg_small_grads_max_bytes = %d and " + "agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_alg, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) destinations = per_device_values[0].devices grouped = _group_value_by_device(per_device_values) - if self.num_packs > 0: - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s and num_packs = %d", len(per_device_values), - self.all_reduce_alg, self.num_packs) - tensor_packer = ConcatAndSplitPacker(self.num_packs) - device_grad_packs = tensor_packer.pack(grouped) - elif (self.agg_small_grads_max_bytes > 0 and - self.agg_small_grads_max_group > 0): - logging.info( - "batch_all_reduce invoked for batches size = %d with " - "algorithm = %s, agg_small_grads_max_bytes = %d and " - "agg_small_grads_max_group = %d", len(per_device_values), - self.all_reduce_alg, self.agg_small_grads_max_bytes, - self.agg_small_grads_max_group) - tensor_packer = AggregateSmallTensorPacker( - self.agg_small_grads_max_bytes, self.agg_small_grads_max_group) - device_grad_packs = tensor_packer.pack(grouped) - else: - logging.info( - "batch_all_reduce invoked for batches size = %d with algorithm = %s", - len(per_device_values), self.all_reduce_alg) - tensor_packer = None - device_grad_packs = grouped + + device_grad_packs, self._tensor_packer = _pack_tensors( + grouped, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) # The actual aggregation of the repacked gradients. Note that they are # sharded among different aggregation trees. So it is important to strike # the balance on num_splits. - if self.all_reduce_alg == "nccl": + if self._all_reduce_alg == "nccl": + # TODO(yuefengz): merge this into the all-reduce library. reduced = cross_tower_utils.aggregate_gradients_using_nccl( device_grad_packs) else: @@ -510,13 +554,137 @@ class AllReduceCrossTowerOps(CrossTowerOps): cross_tower_utils.aggregate_gradients_using_hierarchical_copy( destinations, device_grad_packs)) - if tensor_packer: - reduced = tensor_packer.unpack(reduced) - + reduced = _unpack_tensors(reduced, self._tensor_packer) return _ungroup_and_make_mirrored(reduced, per_device_values[0].devices, method_string) +AllReduceSpecTuple = collections.namedtuple("AllReduceSpecTuple", + "alg shards limit") + + +class MultiWorkerAllReduce(AllReduceCrossTowerOps): + """All-reduce algorithms for distributed TensorFlow.""" + + def __init__(self, + worker_devices, + num_gpus_per_worker, + all_reduce_spec=("pscpu/pscpu", 2, -1), + num_packs=0, + agg_small_grads_max_bytes=0, + agg_small_grads_max_group=10): + """Initialize the all-reduce algorithm. + + Args: + worker_devices: a list of device strings for workers participating in + all-reduce. + num_gpus_per_worker: number of GPU devices per worker. + all_reduce_spec: a tuple or a named tuple or a list of tuples specifying + the all-reduce algorithm. + 1. The first element of a tuple is the name of the all-reduce algorithm. + Valid algorithm names are: "nccl", "nccl/xring", "nccl/rechd", + "nccl/pscpu", "xring", "pscpu", "psgpu", "pscpu/pscpu". Algorithms with + a "/" are hierarchical, so two all-reduces are executed, the first one + aggregates tensors within a worker and the second aggregates across + workers. + 2. The second element of a tuple is the number of shards when doing + all-reduce. Let's say its values is M, each tensor after packing will be + split into M shards and then M parallel all-reduces would be performed + before finally they are concatenated backed into a complete tensor. + 3. The third element is the maximum size of tensors that will be + applicable for the algorithm specified by the first element. For + example, if all_reduce_spec=[("nccl", 2, 1024), ("pscpu/pscpu", 2, -1)], + tensors with size not larger than 1024 bytes will be applied a 2-shard + "nccl" all-reduce and other tensors will be applied a 2-shard + "pscpu/pscpu" algorithm. The third elements should be in increasing + order across tuples and end with -1 which indicates infinity. + num_packs: see AllReduceCrossTowerOps. + agg_small_grads_max_bytes: see AllReduceCrossTowerOps. + agg_small_grads_max_group: see AllReduceCrossTowerOps. + """ + self._worker_devices = worker_devices + self._num_gpus_per_worker = num_gpus_per_worker + super(MultiWorkerAllReduce, self).__init__( + num_packs=num_packs, + agg_small_grads_max_bytes=agg_small_grads_max_bytes, + agg_small_grads_max_group=agg_small_grads_max_group) + + def validate_and_complete_spec(spec): + """Validate and complete the all-reduce spec.""" + # TODO(yuefengz): support namedtuple. + if not isinstance(spec, tuple): + raise ValueError( + "A tuple is expected for all-reduce spec: %r" % all_reduce_spec) + if not spec or len(spec) > 3: + raise ValueError( + "Too many elements in the all-reduce spec tuple: %r" % spec) + if len(spec) == 1: + return AllReduceSpecTuple(spec[0], 1, -1) + elif len(spec) == 2: + return AllReduceSpecTuple(spec[0], spec[1], -1) + else: + return AllReduceSpecTuple(*spec) + + self._all_reduce_spec = [] + if isinstance(all_reduce_spec, six.string_types): + self._all_reduce_spec.append(AllReduceSpecTuple(all_reduce_spec, 1, -1)) + elif isinstance(all_reduce_spec, tuple): + self._all_reduce_spec.append(validate_and_complete_spec(all_reduce_spec)) + elif isinstance(all_reduce_spec, list): + self._all_reduce_spec = [ + validate_and_complete_spec(spec) for spec in all_reduce_spec + ] + + def _batch_all_reduce(self, method_string, per_device_values): + """All reduce algorithm in a batch.""" + logging.info( + "distributed batch_all_reduce invoked for batches size = %d with " + "allreduce_spec = %r, num_packs = %d, agg_small_grads_max_bytes = %d " + "and agg_small_grads_max_group = %d", len(per_device_values), + self._all_reduce_spec, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + + destinations = sorted(per_device_values[0].devices) + device_grads = _group_value_by_device(per_device_values) + + # The all reduce library requires fully defined shapes. + # TODO(yuefengz): when tensor sharding is not needed, static shapes are not + # required as well. + for device_grad in device_grads: + for grad, _ in device_grad: + if not grad.shape.is_fully_defined(): + raise ValueError("Shape is unknown for node %r" % grad) + + remaining_grads = device_grads + aggregated_grads = [] + for spec_tuple in self._all_reduce_spec: + if spec_tuple.limit < 0: + this_grads = remaining_grads + remaining_grads = [] + else: + (this_grads, remaining_grads) = cross_tower_utils.split_grads_by_size( + spec_tuple.limit, remaining_grads) + if this_grads: + device_grad_packs, self._tensor_packer = _pack_tensors( + this_grads, self._num_packs, self._agg_small_grads_max_bytes, + self._agg_small_grads_max_group) + range_agg_grads = cross_tower_utils.sum_gradients_all_reduce( + self._worker_devices, device_grad_packs, len(self._worker_devices), + spec_tuple.alg, spec_tuple.shards, range(self._num_gpus_per_worker)) + range_agg_grads = _unpack_tensors(range_agg_grads, self._tensor_packer) + + if not aggregated_grads: + aggregated_grads = range_agg_grads + else: + assert len(aggregated_grads) == len(range_agg_grads) + for i in range(len(aggregated_grads)): + aggregated_grads[i] += range_agg_grads[i] + assert not remaining_grads + + return _ungroup_and_make_mirrored(aggregated_grads, destinations, + method_string) + + _dgx1_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py index 7c7b0870887465ec2fe40007695d099277db38bf..fed5505d92ef2544215069736c166a67d6141708 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py +++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py @@ -24,6 +24,7 @@ from absl.testing import parameterized from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib +from tensorflow.contrib.distribute.python import multi_worker_test_base from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.eager import context from tensorflow.python.eager import test @@ -31,6 +32,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util def _make_per_device(values, devices): @@ -56,19 +58,46 @@ def _fake_mirrored(value, devices): {d: v for d, v in zip(devices, [value] * len(devices))}) +def _make_indexed_slices(values, indices, dense_shape, device): + with ops.device(device): + tensor = ops.IndexedSlices( + values=constant_op.constant(values), + indices=constant_op.constant(indices), + dense_shape=constant_op.constant(dense_shape)) + return tensor + + +def _make_mirrored_indexed_slices(devices, values, indices, dense_shape): + return value_lib.Mirrored({ + d: _make_indexed_slices(values, indices, dense_shape, d) for d in devices + }) + + _cpu_device = "/device:CPU:0" -class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): +class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase): - def _assert_value_equal(self, left, right): + def _assert_indexed_slices_equal(self, left, right): + self.assertIsInstance(left, ops.IndexedSlices) + self.assertIsInstance(right, ops.IndexedSlices) + self.assertEqual(device_util.resolve(left.device), + device_util.resolve(right.device)) + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + def _assert_values_equal(self, left, right): if isinstance(left, list): for l, r in zip(left, right): - self._assert_value_equal(l, r) + self._assert_values_equal(l, r) else: self.assertEqual(type(left), type(right)) self.assertEqual(left.devices, right.devices) - if context.executing_eagerly(): + if isinstance(list(left._index.values())[0], ops.IndexedSlices): + for (d, v) in left._index.items(): + self._assert_indexed_slices_equal(v, right._index[d]) + elif context.executing_eagerly(): self.assertEqual([v.numpy() for v in left._index.values()], list(right._index.values())) else: @@ -76,51 +105,7 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): self.assertEqual( sess.run(list(left._index.values())), list(right._index.values())) - # TODO(yuefengz): decouple the num_gpus check from distribution in - # combinations module so that we can pass in devices instead of a distribution - # strategy. - reduction_to_one_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "DefaultReductionToOneDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), - combinations.NamedObject( - "ReductionToCPUDeviceCrossTowerOps", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - reduce_to_device=_cpu_device)), - combinations.NamedObject( - "AccumulateNCrossTowerOp", - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( - accumulation_fn=math_ops.accumulate_n)), - ], - distribution=[ - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus - ], - mode=["graph", "eager"]) - allreduce_combinations = combinations.combine( - cross_tower_ops=[ - combinations.NamedObject( - "AllReduce", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), - combinations.NamedObject( - "HierarchicalCopy", - cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 8, 0, 0)), - combinations.NamedObject( - "AllReduceNoGradientRepacking", - cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), - combinations.NamedObject( - "HierarchicalCopyAggregateSmallTensors", - cross_tower_ops_lib.AllReduceCrossTowerOps( - "hierarchical_copy", 0, 100, 10)) - ], - distribution=[combinations.mirrored_strategy_with_two_gpus], - mode=["graph", "eager"]) - - @combinations.generate(reduction_to_one_combinations + allreduce_combinations) - def testReductionAndBroadcast(self, cross_tower_ops, distribution): + def _testReductionAndBroadcast(self, cross_tower_ops, distribution): devices = distribution.worker_devices values = [constant_op.constant(float(d)) for d in range(len(devices))] @@ -143,29 +128,29 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): # test reduce() for destinations in all_destinations: - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce("mean", per_device, destinations=destinations), _fake_mirrored(mean, destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce( "mean", per_device_2, destinations=destinations), _fake_mirrored(mean_2, destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce("sum", per_device, destinations=destinations), _fake_mirrored(mean * len(devices), destinations or per_device)) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.reduce( "sum", per_device_2, destinations=destinations), _fake_mirrored(mean_2 * len(devices), destinations or per_device)) # test batch_reduce() for d1, d2 in itertools.product(all_destinations, all_destinations): - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.batch_reduce( "mean", [(per_device, d1), (per_device_2, d2)]), [_fake_mirrored(mean, d1 or per_device), _fake_mirrored(mean_2, d2 or per_device_2)]) - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.batch_reduce( "sum", [(per_device, d1), (per_device_2, d2)]), [_fake_mirrored(mean * len(devices), d1 or per_device), @@ -176,45 +161,205 @@ class CrossTowerOpsTest(test.TestCase, parameterized.TestCase): if destinations is None: continue else: - self._assert_value_equal( + self._assert_values_equal( cross_tower_ops.broadcast(constant_op.constant(1.), destinations), _fake_mirrored(1., destinations)) + +class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase): + # TODO(yuefengz): decouple the num_gpus check from distribution in + # combinations module so that we can pass in devices instead of a distribution + # strategy. + reduction_to_one_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "DefaultReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "ReductionToCPUDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + reduce_to_device=_cpu_device)), + combinations.NamedObject( + "AccumulateNCrossTowerOp", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps( + accumulation_fn=math_ops.accumulate_n)), + ], + distribution=[ + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus + ], + mode=["graph", "eager"]) + allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "AllReduce", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 1, 0, 0)), + combinations.NamedObject( + "HierarchicalCopy", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 8, 0, 0)), + combinations.NamedObject( + "AllReduceNoGradientRepacking", + cross_tower_ops_lib.AllReduceCrossTowerOps("nccl", 0, 0, 0)), + combinations.NamedObject( + "HierarchicalCopyAggregateSmallTensors", + cross_tower_ops_lib.AllReduceCrossTowerOps( + "hierarchical_copy", 0, 100, 10)) + ], + distribution=[combinations.mirrored_strategy_with_two_gpus], + mode=["graph", "eager"]) + + @combinations.generate(reduction_to_one_combinations + allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) + def testChooseAlgorithm(self): device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if there are only 4 devices device_links = [[1, 2, 3, 4], [0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "nccl") - self.assertEqual(result.num_packs, 1) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) # if devices links contain each device itself device_links = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 5], [0, 1, 2, 3, 6], [0, 1, 2, 3, 7], [0, 4, 5, 6, 7], [1, 4, 5, 6, 7], [2, 4, 5, 6, 7], [3, 4, 5, 6, 7]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "hierarchical_copy") - self.assertEqual(result.num_packs, 8) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "hierarchical_copy") + self.assertEqual(result._num_packs, 8) # if not dgx1-like links device_links = [[0, 2, 3, 5], [0, 1, 3, 6], [0, 1, 2, 7], [0, 5, 6, 7], [1, 4, 6, 7], [2, 4, 5, 7], [3, 4, 5, 6], [1, 2, 3, 4]] result = cross_tower_ops_lib._choose_all_reduce_algorithm(device_links) - self.assertTrue( - isinstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps)) - self.assertEqual(result.all_reduce_alg, "nccl") - self.assertEqual(result.num_packs, 1) + self.assertIsInstance(result, cross_tower_ops_lib.AllReduceCrossTowerOps) + self.assertEqual(result._all_reduce_alg, "nccl") + self.assertEqual(result._num_packs, 1) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testSimpleReduceWithIndexedSlices(self): + devices = ["/cpu:0", "/gpu:0"] + t0 = _make_indexed_slices([[1., 2.]], [1], [5, 2], devices[0]) + t1 = _make_indexed_slices([[3., 4.], [5., 6.]], [1, 3], [5, 2], devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + result = cross_tower_ops_lib._simple_reduce(per_device, devices[0], + math_ops.add_n, "sum") + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices with and without duplicate indices. + total_with_dups = _make_indexed_slices( + [[1., 2.], [3., 4.], [5., 6.]], [1, 1, 3], [5, 2], devices[0]) + total_without_dups = _make_indexed_slices( + [[4., 6.], [5., 6.]], [1, 3], [5, 2], devices[0]) + self._assert_indexed_slices_equal(total_with_dups, result) + self._assert_indexed_slices_equal(total_without_dups, result) + + @combinations.generate(combinations.combine( + cross_tower_ops_instance=[ + combinations.NamedObject( + "ReductionToOneDeviceCrossTowerOps", + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()), + combinations.NamedObject( + "AllReduceCrossTowerOps", + cross_tower_ops_lib.AllReduceCrossTowerOps()) + ], + method_string=["sum", "mean"], + batch_reduce=[True, False], + mode=["graph", "eager"], + required_gpus=1)) + def testIndexedSlicesAllReduce(self, cross_tower_ops_instance, + method_string, batch_reduce): + devices = ["/cpu:0", "/gpu:0"] + dense_shape = [5, 2] + t0 = _make_indexed_slices([[1., 2.]], [1], dense_shape, devices[0]) + t1 = _make_indexed_slices( + [[3., 4.], [5., 6.]], [1, 3], dense_shape, devices[1]) + per_device = value_lib.PerDevice({devices[0]: t0, devices[1]: t1}) + + if batch_reduce: + result = cross_tower_ops_instance.batch_reduce(method_string, + [(per_device, devices)]) + else: + result = cross_tower_ops_instance.reduce(method_string, per_device, + devices) + + total_indices_with_dups = [1, 1, 3] + total_indices_without_dups = [1, 3] + + if method_string == "sum": + total_values_with_dups = [[1., 2.], [3., 4.], [5., 6.]] + total_values_without_dups = [[4., 6.], [5., 6.]] + else: + assert method_string == "mean" + total_values_with_dups = [[0.5, 1.], [1.5, 2.], [2.5, 3.]] + total_values_without_dups = [[2., 3.], [2.5, 3.]] + + total_mirrored_with_dups = _make_mirrored_indexed_slices( + devices, total_values_with_dups, total_indices_with_dups, dense_shape) + total_mirrored_without_dups = _make_mirrored_indexed_slices( + devices, total_values_without_dups, total_indices_without_dups, + dense_shape) + + # Test that the result is semantically equal to both the concatenated + # IndexedSlices, as well as when the duplicate indices are summed up. + if batch_reduce: + total_mirrored_with_dups = [total_mirrored_with_dups] + total_mirrored_without_dups = [total_mirrored_without_dups] + + self._assert_values_equal(total_mirrored_with_dups, result) + self._assert_values_equal(total_mirrored_without_dups, result) + + +class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase, + CrossTowerOpsTestBase): + + worker_devices = [ + "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1" + ] + multi_worker_allreduce_combinations = combinations.combine( + cross_tower_ops=[ + combinations.NamedObject( + "MultiWorkerAllReduce", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReducePack", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 1, 0, 0)), + combinations.NamedObject( + "MultiWorkerAllReduceAggregation", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, ("pscpu/pscpu", 2, -1), 0, 100, 10)), + combinations.NamedObject( + "MultiWorkerAllReduceMultipleSpecs", + cross_tower_ops_lib.MultiWorkerAllReduce( + worker_devices, 2, [("pscpu/pscpu", 2, 100), + ("xring", 2, -1)], 0, 0, 0)), + ], + distribution=[ + combinations.multi_worker_strategy_with_cpu, + combinations.multi_worker_strategy_with_one_gpu, + combinations.multi_worker_strategy_with_two_gpus + ], + mode=["graph"]) + + @combinations.generate(multi_worker_allreduce_combinations) + def testReductionAndBroadcast(self, cross_tower_ops, distribution): + with distribution.scope(): + self._testReductionAndBroadcast(cross_tower_ops, distribution) if __name__ == "__main__": diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils.py b/tensorflow/contrib/distribute/python/cross_tower_utils.py index fc04e2195f6d305e0f7c642f24c355286f1a8cfa..2bb088e704c584598b863b1b836166af2a5bb12c 100644 --- a/tensorflow/contrib/distribute/python/cross_tower_utils.py +++ b/tensorflow/contrib/distribute/python/cross_tower_utils.py @@ -21,9 +21,12 @@ from __future__ import print_function import collections as pycoll from tensorflow.contrib import nccl +from tensorflow.contrib.all_reduce.python import all_reduce +from tensorflow.contrib.distribute.python import values as value_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops @@ -156,6 +159,148 @@ def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, return (grad, v), None +def group_device_names(devices, group_size): + """Group device names into groups of group_size. + + Args: + devices: a list of canonical device strings. + group_size: integer which is equal to or greater than 1. + + Returns: + list of lists of devices, where each inner list is group_size long, + and each device appears at least once in an inner list. If + len(devices) % group_size == 0 then each device will appear exactly once. + + Raises: + ValueError: if group_size > len(devices) + """ + num_devices = len(devices) + if group_size > num_devices: + raise ValueError( + 'only %d devices, but group_size=%d' % (num_devices, group_size)) + num_groups = ( + num_devices // group_size + (1 if (num_devices % group_size != 0) else 0)) + groups = [[] for i in range(num_groups)] + for i in range(num_groups * group_size): + groups[i % num_groups].append(devices[i % num_devices]) + return groups + + +def split_grads_by_size(threshold_size, device_grads): + """Break gradients into two sets according to tensor size. + + Args: + threshold_size: int size cutoff for small vs large tensor. + device_grads: List of lists of (gradient, variable) tuples. The outer + list is over devices. The inner list is over individual gradients. + + Returns: + small_grads: Subset of device_grads where shape is <= threshold_size + elements. + large_grads: Subset of device_grads where shape is > threshold_size + elements. + """ + small_grads = [] + large_grads = [] + for dl in device_grads: + small_dl = [] + large_dl = [] + for (g, v) in dl: + tensor_size = g.get_shape().num_elements() + if tensor_size <= threshold_size: + small_dl.append([g, v]) + else: + large_dl.append([g, v]) + if small_dl: + small_grads.append(small_dl) + if large_dl: + large_grads.append(large_dl) + return small_grads, large_grads + + +def sum_grad_and_var_all_reduce(grad_and_vars, + num_workers, + alg, + gpu_indices, + aux_devices=None, + num_shards=1): + """Apply all-reduce algorithm over specified gradient tensors.""" + with ops.name_scope('allreduce'): + # Note that each grad_and_vars looks like the following: + # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) + scaled_grads = [g for g, _ in grad_and_vars] + if alg == 'nccl': + summed_grads = nccl.all_sum(scaled_grads) + elif alg == 'xring': + summed_grads = all_reduce.build_ring_all_reduce( + scaled_grads, num_workers, num_shards, gpu_indices, math_ops.add) + elif alg == 'nccl/xring': + summed_grads = all_reduce.build_nccl_then_ring(scaled_grads, num_shards, + math_ops.add) + elif alg == 'nccl/rechd': + summed_grads = all_reduce.build_nccl_then_recursive_hd( + scaled_grads, math_ops.add) + elif alg == 'nccl/pscpu': + summed_grads = all_reduce.build_nccl_then_shuffle( + scaled_grads, aux_devices, math_ops.add, math_ops.add_n) + elif alg == 'pscpu/pscpu': + second_gather_devices = aux_devices[:num_shards] + summed_grads = all_reduce.build_shuffle_then_shuffle( + scaled_grads, aux_devices, second_gather_devices, math_ops.add_n) + elif alg in ['pscpu', 'psgpu']: + summed_grads = all_reduce.build_shuffle_all_reduce( + scaled_grads, aux_devices, math_ops.add_n) + else: + raise ValueError('unsupported all_reduce alg: ', alg) + + result = [] + for (_, v), g in zip(grad_and_vars, summed_grads): + result.append([g, v]) + return result + + +def sum_gradients_all_reduce(dev_prefixes, tower_grads, num_workers, alg, + num_shards, gpu_indices): + """Apply all-reduce algorithm over specified gradient tensors. + + Args: + dev_prefixes: list of prefix strings to use to generate PS device names. + tower_grads: the gradients to reduce. + num_workers: number of worker processes across entire job. + alg: the all-reduce algorithm to apply. + num_shards: alg-specific sharding factor. + gpu_indices: indices of local GPUs in order usable for ring-reduce. + + Returns: + list of reduced tensors + """ + alg_contains_shuffle = any([n in alg for n in ['pscpu', 'psgpu']]) + is_hierarchical = '/' in alg + if 'pscpu' in alg: + aux_devices = [prefix + '/cpu:0' for prefix in dev_prefixes] + elif 'psgpu' in alg: + aux_devices = [ + prefix + '/gpu:%d' % i + for i in range(len(gpu_indices)) + for prefix in dev_prefixes + ] + else: + aux_devices = ['/job:localhost/cpu:0'] + # Auxiliary devices for hierarchical all-reduces. + aux_device_groups = group_device_names( + aux_devices, num_shards if alg_contains_shuffle else 1) + group_index = 0 + reduced_gv_list = [] + for grad_and_vars in zip(*tower_grads): + reduced_gv_list.append( + sum_grad_and_var_all_reduce( + grad_and_vars, num_workers, alg, gpu_indices, aux_devices + if is_hierarchical else aux_device_groups[group_index], num_shards)) + group_index = (group_index + 1) % len(aux_device_groups) + new_tower_grads = [list(x) for x in zip(*reduced_gv_list)] + return new_tower_grads + + def extract_ranges(index_list, range_size_limit=32): """Extract consecutive ranges and singles from index_list. @@ -328,7 +473,7 @@ def unpack_small_tensors(tower_grads, packing): for dev_idx, gv_list in enumerate(tower_grads): gv_list = list(gv_list) new_gv_list = gv_list[num_packed:] - for i in xrange(0, num_packed): + for i in range(num_packed): k = '%d:%d' % (dev_idx, i) gpt = packing[k] gv = unpack_grad_tuple(gv_list[i], gpt) @@ -337,3 +482,46 @@ def unpack_small_tensors(tower_grads, packing): new_gv_list.insert(idx, gv[gi]) new_tower_grads.append(new_gv_list) return new_tower_grads + + +def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): + """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" + if any(isinstance(v, ops.IndexedSlices) for v in values): + return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access + else: + return accumulation_fn(values) + + +def divide_by_n_tensors_or_indexed_slices(value, n): + if isinstance(value, ops.IndexedSlices): + value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access + return ops.IndexedSlices( + value.values / n, value.indices, value.dense_shape) + else: + return value / n + + +def copy_tensor_or_indexed_slices_to_device(value, device): + with ops.device(device): + if isinstance(value, ops.IndexedSlices): + copied_values = array_ops.identity(value.values) + copied_indices = array_ops.identity(value.indices) + copied_shape = array_ops.identity(value.dense_shape) + result = ops.IndexedSlices(copied_values, copied_indices, copied_shape) + else: + result = array_ops.identity(value) + return result + + +def contains_indexed_slices(value): + """Check whether the value is `IndexedSlices` or contains `IndexedSlices`.""" + if isinstance(value, ops.IndexedSlices): + return True + elif isinstance(value, (list, tuple)) and value: + return any(contains_indexed_slices(v) for v in value) + elif isinstance(value, value_lib.DistributedValues): + return contains_indexed_slices(list(value._index.values())) # pylint: disable=protected-access + elif isinstance(value, value_lib.MapOutput): + return contains_indexed_slices(value.get()) + else: + return False diff --git a/tensorflow/contrib/distribute/python/cross_tower_utils_test.py b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ef8db681503dcef8c72f641455dbb999cef05cf --- /dev/null +++ b/tensorflow/contrib/distribute/python/cross_tower_utils_test.py @@ -0,0 +1,152 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for cross_tower_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.distribute.python import combinations +from tensorflow.contrib.distribute.python import cross_tower_utils +from tensorflow.contrib.distribute.python import values as value_lib +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.training import device_util + + +class IndexedSlicesUtilsTest(test.TestCase, parameterized.TestCase): + + def _assert_values_equal(self, left, right): + self.assertAllEqual( + self.evaluate(ops.convert_to_tensor(left)), + self.evaluate(ops.convert_to_tensor(right))) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateTensors(self): + t0 = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + t1 = constant_op.constant([[0., 0.], [5, 6], [7., 8.]]) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testAggregateIndexedSlices(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + total = constant_op.constant([[1., 2.], [5, 6], [10., 12.]]) + result = cross_tower_utils.aggregate_tensors_or_indexed_slices([t0, t1]) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(total, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideTensor(self): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testDivideIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + n = 2 + expected = constant_op.constant([[0.5, 1.], [0, 0], [1.5, 2.]]) + result = cross_tower_utils.divide_by_n_tensors_or_indexed_slices(t, n) + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(expected, result) + + @test_util.run_in_graph_and_eager_modes() + def testIsIndexedSlices(self): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices(t)) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_List(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices([t0, t1])) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_Tuple(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + self.assertTrue(cross_tower_utils.contains_indexed_slices((t0, t1))) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_PerDevice(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({"/gpu:0": t0, "/cpu:0": t1}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @test_util.run_in_graph_and_eager_modes() + def testContainsIndexedSlices_PerDeviceMapOutput(self): + t0 = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + t1 = math_ops._as_indexed_slices( + constant_op.constant([[0., 0.], [5, 6], [7., 8.]])) + per_device = value_lib.PerDevice({ + "/gpu:0": value_lib.MapOutput([t0]), + "/cpu:0": value_lib.MapOutput([t1])}) + self.assertTrue(cross_tower_utils.contains_indexed_slices(per_device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyTensor(self): + with ops.device("/cpu:0"): + t = constant_op.constant([[1., 2.], [0, 0], [3., 4.]]) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + @combinations.generate(combinations.combine( + mode=["graph", "eager"], + required_gpus=1)) + def testCopyIndexedSlices(self): + with ops.device("/cpu:0"): + t = math_ops._as_indexed_slices( + constant_op.constant([[1., 2.], [0, 0], [3., 4.]])) + destination = "/gpu:0" + result = cross_tower_utils.copy_tensor_or_indexed_slices_to_device( + t, destination) + + self.assertIsInstance(result, ops.IndexedSlices) + self._assert_values_equal(t, result) + self.assertEqual(device_util.resolve(destination), + device_util.resolve(result.device)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75ecd90dcffa7a786b78238ef453c4c8e4346afa --- /dev/null +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -0,0 +1,148 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Keras Sequential and Functional models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.contrib.distribute.python import mirrored_strategy +from tensorflow.python import keras +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.estimator import keras as keras_lib +from tensorflow.python.estimator import run_config as run_config_lib +from tensorflow.python.framework import test_util +from tensorflow.python.keras import testing_utils +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import rmsprop + +_RANDOM_SEED = 1337 +_TRAIN_SIZE = 200 +_INPUT_SIZE = (10,) +_NUM_CLASS = 2 + + +def simple_sequential_model(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE)) + model.add(keras.layers.Dropout(0.1)) + model.add(keras.layers.Dense(_NUM_CLASS, activation='softmax')) + return model + + +def simple_functional_model(): + a = keras.layers.Input(shape=_INPUT_SIZE) + b = keras.layers.Dense(16, activation='relu')(a) + b = keras.layers.Dropout(0.1)(b) + b = keras.layers.Dense(_NUM_CLASS, activation='softmax')(b) + model = keras.models.Model(inputs=[a], outputs=[b]) + return model + + +def get_ds_train_input_fn(): + np.random.seed(_RANDOM_SEED) + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_train = keras.utils.to_categorical(y_train) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_train, y_train)) + dataset = dataset.batch(32) + return dataset + + +def get_ds_test_input_fn(): + np.random.seed(_RANDOM_SEED) + _, (x_test, y_test) = testing_utils.get_test_data( + train_samples=_TRAIN_SIZE, + test_samples=50, + input_shape=_INPUT_SIZE, + num_classes=_NUM_CLASS) + y_test = keras.utils.to_categorical(y_test) + + dataset = dataset_ops.Dataset.from_tensor_slices((x_test, y_test)) + dataset = dataset.batch(32) + return dataset + + +class TestKerasDistributionStrategy(test_util.TensorFlowTestCase): + + def setUp(self): + self._base_dir = os.path.join(self.get_temp_dir(), + 'keras_mirrored_strategy_test') + gfile.MakeDirs(self._base_dir) + self._config = run_config_lib.RunConfig( + tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir) + + def tearDown(self): + writer_cache.FileWriterCache.clear() + if os.path.isdir(self._base_dir): + gfile.DeleteRecursively(self._base_dir) + + def test_train_functional_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_functional_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + + def test_train_sequential_with_distribution_strategy(self): + dist = mirrored_strategy.MirroredStrategy( + devices=['/device:GPU:0', '/device:GPU:1']) + keras_model = simple_sequential_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer=rmsprop.RMSPropOptimizer(learning_rate=0.01)) + config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED, + model_dir=self._base_dir, + train_distribute=dist) + with self.test_session(): + est_keras = keras_lib.model_to_estimator( + keras_model=keras_model, config=config) + before_eval_results = est_keras.evaluate( + input_fn=get_ds_test_input_fn, steps=1) + est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16) + after_eval_results = est_keras.evaluate(input_fn=get_ds_test_input_fn, + steps=1) + self.assertLess(after_eval_results['loss'], before_eval_results['loss']) + + writer_cache.FileWriterCache.clear() + gfile.DeleteRecursively(self._config.model_dir) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6c6bf143098c1bba64d47efce1bfface7682683d --- /dev/null +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -0,0 +1,438 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for V1 metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables + + +def _labeled_dataset_fn(): + # First four batches of x: labels, predictions -> (labels == predictions) + # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False + # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False + # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False + # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True + return dataset_ops.Dataset.range(1000).map( + lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + + +def _boolean_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # T, T -> TP; F, T -> FP; T, F -> FN + # F, F -> TN; T, T -> TP; F, T -> FP + # T, F -> FN; F, F -> TN; T, T -> TP + # F, T -> FP; T, F -> FN; F, F -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [True, True, False, False]}).repeat().batch(3) + + +def _threshold_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN + # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP + # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP + # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + + +def _regression_dataset_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [1., .5, 1., 0.], + "predictions": [1., .75, .25, 0.]}).repeat() + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + mode=["graph"]) + + +# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, +# metrics.precision_at_k +class MetricsV1Test(test.TestCase, parameterized.TestCase): + + def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): + with ops.Graph().as_default(), distribution.scope(): + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + self.evaluate(variables.local_variables_initializer()) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + # Update variables using the first `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), + 0.001, msg="After first update") + + # Update variables using the second `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(2 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After second update") + + if batches_per_update == 1: # Consume 4 input batches + self.evaluate(update) + self.assertAllClose(expected_fn(3 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After third update") + self.evaluate(update) + self.assertAllClose(expected_fn(4 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After fourth update") + + @combinations.generate(all_combinations()) + def testMean(self, distribution): + def _dataset_fn(): + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + + def _expected_fn(num_batches): + # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. + return num_batches * 2 - 0.5 + + self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) + + @combinations.generate(all_combinations()) + def testAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.accuracy(labels, predictions) + + def _expected_fn(num_batches): + return [3./4, 3./8, 3./12, 4./16][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanPerClassAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_per_class_accuracy( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1., 1., 1., 0., 0.]), + mean([0.5, 0.5, 0.5, 0., 0.]), + mean([1./3, 1./3, 0.5, 0., 0.]), + mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanIOU(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_iou( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch + mean([1./4, 1./4, 1./3, 0., 0.]), + mean([1./6, 1./6, 1./5, 0., 0.]), + mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanTensor(self, distribution): + def _dataset_fn(): + dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) + # Want to produce a fixed, known shape, so drop remainder when batching. + dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + return dataset + + def _expected_fn(num_batches): + # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 + # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 + # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches + # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 + first = 2. * num_batches - 2. + return [first, first + 1., first + 2., first + 3.] + + self._test_metric( + distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCROC(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.5, 7./9, 0.8, 0.75][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCPR(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[0.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 3., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [3.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecision(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 0.5, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecisionAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecall(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecallAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1./32, 0.208333, 0.15625][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRootMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.root_mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSensitivityAtSpecificity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.sensitivity_at_specificity(labels, predictions, 0.8) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSpecificityAtSensitivity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.specificity_at_sensitivity(labels, predictions, 0.95) + + def _expected_fn(num_batches): + return [0., 1./3, 0.5, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 89f2c431fece63269928fec6aa6d23b5a79ba0b9..900aa10e93e8881aa236bac8a2873d5c5531c6f6 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import contextlib import threading import six @@ -30,6 +31,7 @@ from tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import coordinator from tensorflow.python.training import device_util @@ -39,6 +41,16 @@ from tensorflow.python.training import distribute as distribute_lib # TODO(josh11b): Replace asserts in this file with if ...: raise ... +@contextlib.contextmanager +def _enter_graph(g): + if context.executing_eagerly(): + with g.as_default(), context.eager_mode(): + yield + else: + with g.as_default(): + yield + + def _cpu_device(device): cpu_device = tf_device.DeviceSpec.from_string(device) cpu_device.merge_from(tf_device.DeviceSpec(device_type="CPU", device_index=0)) @@ -73,9 +85,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): assert len(set(devices)) == len(devices), ( "No duplicates allowed in `devices` argument.") # TODO(josh11b): Require at least 2 devices? - self._devices = devices - self._canonical_device_set = set( - [device_util.canonicalize(d) for d in devices]) + self._devices = [device_util.resolve(d) for d in devices] + self._canonical_device_set = set(self._devices) self._device_index = values.PerDevice( dict((d, i) for i, d in enumerate(devices))) self._cross_tower_ops = cross_tower_ops @@ -108,7 +119,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if i > 0: # Give replicas meaningful distinct names: var0name = index[devices[0]].name.split(":")[0] - kwargs["name"] = "%s/replica_%d" % (var0name, i) + # We append a / to variable names created on towers with id > 0 to + # ensure that we ignore the name scope and instead use the given + # name as the absolute name of the variable. + kwargs["name"] = "%s/replica_%d/" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): kwargs["initial_value"] = array_ops.identity( @@ -248,8 +262,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): {t.device: t.merge_args for t in threads}) merge_kwargs = values.regroup( {t.device: t.merge_kwargs for t in threads}) - merge_result = threads[0].merge_fn( - self, *merge_args, **merge_kwargs) + # We capture the name_scope of the MTT when we call merge_fn + # to ensure that if we have opened a name scope in the MTT, + # it will be respected when executing the merge function. We only + # capture the name_scope from the first MTT and assume it is + # the same for all other MTTs. + mtt_captured_name_scope = threads[0].captured_name_scope + with ops.name_scope(mtt_captured_name_scope): + merge_result = threads[0].merge_fn( + self, *merge_args, **merge_kwargs) for t in threads: t.merge_result = values.select_device(t.device, merge_result) finally: @@ -323,6 +344,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): **values.select_device_mirrored(d, kwargs)) return values.regroup(updates, values.Mirrored) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + if isinstance(tower_local_var, values.TowerLocalVariable): + return math_ops.add_n(self.unwrap(tower_local_var)) + assert isinstance(tower_local_var, values.Mirrored) + return array_ops.identity(tower_local_var.get()) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" if isinstance(val, values.TowerLocalVariable): @@ -389,7 +417,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # pylint: disable=protected-access return list(colocate_with._index.keys()) elif isinstance(colocate_with, six.string_types): - return [colocate_with] + return [device_util.resolve(colocate_with)] + elif isinstance(colocate_with, list): + return [device_util.resolve(d) for d in colocate_with] else: return colocate_with @@ -416,6 +446,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self.merge_args = None self.merge_kwargs = None self.merge_result = None + self.captured_name_scope = None # We use a thread.Event for the main thread to signal when this # thread should start running (`should_run`), and another for # this thread to transfer control back to the main thread @@ -439,13 +470,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._variable_creator_stack = self.graph._variable_creator_stack[:] self._captured_var_scope = variable_scope.get_variable_scope() # Adding a "/" at end lets us re-enter this scope later. - self._captured_name_scope = self.graph.get_name_scope() - if self._captured_name_scope: - self._captured_name_scope += "/" + self._name_scope = self.graph.get_name_scope() + if self._name_scope: + self._name_scope += "/" if self.tower_id > 0: - if not self._captured_name_scope: - self._captured_name_scope = "" - self._captured_name_scope += "tower_%d/" % self.tower_id + if not self._name_scope: + self._name_scope = "" + self._name_scope += "tower_%d/" % self.tower_id def run(self): # pylint: disable=protected-access @@ -458,10 +489,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): with self.coord.stop_on_exception(), \ context.context()._mode(self.context_mode), \ context.context().device_policy(self.context_device_policy), \ - self.graph.as_default(), \ + _enter_graph(self.graph), \ MirroredTowerContext(self.distribution, self.tower_id), \ ops.device(self.device), \ - ops.name_scope(self._captured_name_scope), \ + ops.name_scope(self._name_scope), \ variable_scope.variable_scope( self._captured_var_scope, reuse=self.tower_id > 0), \ variable_scope.variable_creator_scope(self.variable_creator_fn): @@ -487,6 +518,10 @@ class MirroredTowerContext(distribute_lib.TowerContext): t.merge_fn = fn t.merge_args = args t.merge_kwargs = kwargs + t.captured_name_scope = t.graph.get_name_scope() + # Adding a "/" at end lets us re-enter this scope later. + if t.captured_name_scope: + t.captured_name_scope += "/" t.has_paused.set() t.should_run.wait() t.should_run.clear() diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 3f9a02b249dde9a66056ed8952b664bbc3f74ead..bccd278847e3c87080af3cb15665e7a0d802d8fb 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -438,6 +438,74 @@ class MirroredStrategyVariableCreationTest(test.TestCase): self.assertEquals("foo/" + name + ":0", v0.name) self.assertEquals("tower_1/foo/" + name + ":0", v1.name) + # variable_scope.variable() respects name scopes when creating + # variables. On the other hand variable_scope.get_variable() ignores name + # scopes when creating variables. We test both methods of creating variables + # to make sure that we have the same variable names in both cases. + def testNameScopeWithVariable(self): + def in_cross_tower(_): + c = variable_scope.variable(1.0, name="c") + return c + + def model_fn(): + b = variable_scope.variable(1.0, name="b") + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.variable(1.0, name="a") + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("main/a:0", a0.name) + self.assertEquals("main/a/replica_1:0", a1.name) + self.assertEquals("main/b:0", b0.name) + self.assertEquals("main/b/replica_1:0", b1.name) + self.assertEquals("main/foo/c:0", c0.name) + self.assertEquals("main/foo/c/replica_1:0", c1.name) + + def testNameScopeWithGetVariable(self): + def in_cross_tower(_): + c = variable_scope.get_variable("c", [1]) + return c + + def model_fn(): + b = variable_scope.get_variable("b", [1]) + with ops.name_scope("foo"): + c = distribute_lib.get_tower_context().merge_call(in_cross_tower) + return b, c + + dist = mirrored_strategy.MirroredStrategy( + ["/device:GPU:0", "/device:CPU:0"]) + + with context.graph_mode(), dist.scope(): + with ops.name_scope("main"): + a = variable_scope.get_variable("a", [1]) + result = dist.call_for_each_tower(model_fn, run_concurrently=False) + result_b = result[0] + result_c = result[1] + self.assertIsInstance(result_b, values.DistributedValues) + self.assertIsInstance(result_c, values.DistributedValues) + a0, a1 = dist.unwrap(a) + b0, b1 = dist.unwrap(result_b) + c0, c1 = dist.unwrap(result_c) + self.assertEquals("a:0", a0.name) + self.assertEquals("a/replica_1:0", a1.name) + self.assertEquals("b:0", b0.name) + self.assertEquals("b/replica_1:0", b1.name) + self.assertEquals("c:0", c0.name) + self.assertEquals("c/replica_1:0", c1.name) + def testDynamicRnnVariables(self): def model_fn(): inputs = constant_op.constant(2 * [2 * [[0.0, 1.0, 2.0, 3.0, 4.0]]]) diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py index 8277e1e7919e86ef616b31d0986589dcc9c49bbd..4fdb9bf69b4f6ad76b79fd298f5303f24a1bd455 100644 --- a/tensorflow/contrib/distribute/python/monitor_test.py +++ b/tensorflow/contrib/distribute/python/monitor_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.distribute.python import combinations from tensorflow.contrib.distribute.python import monitor as monitor_lib from tensorflow.contrib.distribute.python import one_device_strategy from tensorflow.contrib.distribute.python.single_loss_example import single_loss_example +from tensorflow.python.client import session from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import ops @@ -65,7 +66,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase): step_function, _ = single_loss_example( lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution) - with self.test_session() as sess: + with session.Session() as sess, context.eager_mode(): with self.assertRaisesRegexp(ValueError, "Should not provide"): _ = monitor_lib.Monitor(step_function, sess) diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py index a552b370ebf359464afcaf3211119e73434e0dfb..0f21a427320510635279f80c11711e81715ec37c 100644 --- a/tensorflow/contrib/distribute/python/multi_worker_strategy.py +++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py @@ -121,7 +121,7 @@ class MultiWorkerMirroredStrategy(MirroredStrategy): worker: [device_util.canonicalize(worker, '/device:CPU:0')] for worker in self._workers } - self._devices = nest.flatten(self._worker_device_map.values()) + self._devices = nest.flatten(self._worker_device_map) super(MultiWorkerMirroredStrategy, self).__init__( devices=self._devices, prefetch_on_device=prefetch_on_device) diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index 09b6d4a515ab46879520f304cd5ef60469512380..7f4bab9d93814eb70a2a1586fc291a16b2766b90 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -102,6 +102,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device), distribute_lib.UpdateContext(self._device): return fn(*args, **kwargs) + def read_var(self, tower_local_var): + """Read the aggregate value of a tower-local variable.""" + return array_ops.identity(tower_local_var) + def _fetch(self, val, destination, fn): """Return a copy of `val` or `fn(val)` on `destination`.""" with ops.device(self._device): diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 49b4e24daa4ffe417712bc854aa29995d5afc408..9572ade8e497fa13a7ca0746399d3e0237ee79fd 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -65,9 +65,10 @@ class DistributedValues(object): device = device_util.canonicalize(device) try: return self._index[device] - except KeyError: - raise ValueError("Device %s not found in %s (current device %s)" % - (device, self._index.keys(), device_util.current())) + except KeyError as e: + six.raise_from( + ValueError("Device %s not found in %s (current device %s)" % + (device, self._index.keys(), device_util.current())), e) def on_device(self, device): device = device_util.canonicalize(device) diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 6192f04c8b695d124b498850ad430823b44fd472..ad00d1734dd14ed846522a33d888a5387cb25cc6 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -16,6 +16,13 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "bijectors_py", srcs = glob(["python/ops/bijectors/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/linalg:linalg_py", @@ -42,6 +49,13 @@ py_library( py_library( name = "distributions_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + deprecation = ("TensorFlow Distributions has migrated to " + + "TensorFlow Probability " + + "(https://github.com/tensorflow/probability). " + + "Deprecated copies remaining in tf.contrib.distributions " + + "are unmaintained, unsupported, and will be removed by " + + "late 2018. You should update all usage of " + + "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ ":bijectors_py", @@ -940,6 +954,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "fill_triangular_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/fill_triangular_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "gumbel_test", size = "small", @@ -1032,6 +1065,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "matrix_inverse_tril_test", + size = "medium", + srcs = ["python/kernel_tests/bijectors/matrix_inverse_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "real_nvp_test", size = "small", @@ -1099,6 +1151,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "scale_tril_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/scale_tril_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", @@ -1216,6 +1287,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "transform_diagonal_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/transform_diagonal_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "weibull_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index ddf59891e626a85e6c917ac74b3cfaabf16eb15d..5cec93c4df2e970f203253be6342bb292f296eb0 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Classes representing statistical distributions and ops for working with them. - -See the @{$python/contrib.distributions} guide. """ from __future__ import absolute_import from __future__ import division @@ -32,6 +30,7 @@ from tensorflow.contrib.distributions.python.ops.conditional_distribution import from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular_inverse from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse @@ -156,6 +155,7 @@ _allowed_symbols = [ 'kl_divergence', 'RegisterKL', 'fill_triangular', + 'fill_triangular_inverse', 'matrix_diag_transform', 'reduce_weighted_logsumexp', 'softplus_inverse', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py index 8b279ebcd908b6f375b35594ac5f3db9228a1e31..f8a52615b0f3f5ad0c7e01e0f76c7d7a6b455ef7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/conditional_bijector_test.py @@ -59,7 +59,7 @@ class ConditionalBijectorTest(test.TestCase): for name in ["inverse_log_det_jacobian", "forward_log_det_jacobian"]: method = getattr(b, name) with self.assertRaisesRegexp(ValueError, name + ".*b1.*b2"): - method(1., event_ndims=0., arg1="b1", arg2="b2") + method(1., event_ndims=0, arg1="b1", arg2="b2") if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py new file mode 100644 index 0000000000000000000000000000000000000000..caeaf2a0c6e4fff28c0edd82cb09ca0bcee85fc3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/fill_triangular_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class FillTriangularBijectorTest(test.TestCase): + """Tests the correctness of the FillTriangular bijector.""" + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.array([1., 2., 3.])) + y = np.float32(np.array([[3., 0.], + [2., 1.]])) + + b = bijectors.FillTriangular() + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + self.assertAllClose(fldj, 0.) + + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(ildj, 0.) + + @test_util.run_in_graph_and_eager_modes() + def testShape(self): + x_shape = tensor_shape.TensorShape([5, 4, 6]) + y_shape = tensor_shape.TensorShape([5, 4, 3, 3]) + + b = bijectors.FillTriangular(validate_args=True) + + x = array_ops.ones(shape=x_shape, dtype=dtypes.float32) + y_ = b.forward(x) + self.assertAllEqual(y_.shape.as_list(), y_shape.as_list()) + x_ = b.inverse(y_) + self.assertAllEqual(x_.shape.as_list(), x_shape.as_list()) + + y_shape_ = b.forward_event_shape(x_shape) + self.assertAllEqual(y_shape_.as_list(), y_shape.as_list()) + x_shape_ = b.inverse_event_shape(y_shape) + self.assertAllEqual(x_shape_.as_list(), x_shape.as_list()) + + y_shape_tensor = self.evaluate( + b.forward_event_shape_tensor(x_shape.as_list())) + self.assertAllEqual(y_shape_tensor, y_shape.as_list()) + x_shape_tensor = self.evaluate( + b.inverse_event_shape_tensor(y_shape.as_list())) + self.assertAllEqual(x_shape_tensor, x_shape.as_list()) + + @test_util.run_in_graph_and_eager_modes() + def testShapeError(self): + + b = bijectors.FillTriangular(validate_args=True) + + x_shape_bad = tensor_shape.TensorShape([5, 4, 7]) + with self.assertRaisesRegexp(ValueError, "is not a triangular number"): + b.forward_event_shape(x_shape_bad) + with self.assertRaisesOpError("is not a triangular number"): + self.evaluate(b.forward_event_shape_tensor(x_shape_bad.as_list())) + + y_shape_bad = tensor_shape.TensorShape([5, 4, 3, 2]) + with self.assertRaisesRegexp(ValueError, "Matrix must be square"): + b.inverse_event_shape(y_shape_bad) + with self.assertRaisesOpError("Matrix must be square"): + self.evaluate(b.inverse_event_shape_tensor(y_shape_bad.as_list())) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..18397035571561731698b06d90e20dc74e3cf83c --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class MatrixInverseTriLBijectorTest(test.TestCase): + """Tests the correctness of the Y = inv(tril) transformation.""" + + @test_util.run_in_graph_and_eager_modes() + def testComputesCorrectValues(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + self.assertEqual("matrix_inverse_tril", inv.name) + x_ = np.array([[0.7, 0., 0.], + [0.1, -1., 0.], + [0.3, 0.25, 0.5]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -6. * np.sum(np.log(np.abs(np.diag(x_)))) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testOneByOneMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[5.]], dtype=np.float32) + x_inv_ = np.array([[0.2]], dtype=np.float32) + expected_fldj_ = np.log(0.04) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testZeroByZeroMatrix(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.eye(0, dtype=np.float32) + x_inv_ = np.eye(0, dtype=np.float32) + expected_fldj_ = 0. + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertNear(expected_fldj_, fldj_, err=1e-3) + self.assertNear(-expected_fldj_, ildj_, err=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testBatch(self): + # Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape + # (2, 1). + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[[[1., 0.], + [2., 3.]]], + [[[4., 0.], + [5., -6.]]]], dtype=np.float32) + x_inv_ = np.linalg.inv(x_) + expected_fldj_ = -4. * np.sum( + np.log(np.abs(np.diagonal(x_, axis1=-2, axis2=-1))), axis=-1) + + y = inv.forward(x_) + x_back = inv.inverse(x_inv_) + fldj = inv.forward_log_det_jacobian(x_, event_ndims=2) + ildj = inv.inverse_log_det_jacobian(x_inv_, event_ndims=2) + + y_, x_back_, fldj_, ildj_ = self.evaluate([y, x_back, fldj, ildj]) + + self.assertAllClose(x_inv_, y_, atol=0., rtol=1e-5) + self.assertAllClose(x_, x_back_, atol=0., rtol=1e-5) + self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3) + self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3) + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputRankTooLow(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([0.1], dtype=np.float32) + rank_error_msg = "must have rank at least 2" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + # TODO(b/80481923): Figure out why these assertions fail, and fix them. + ## def testErrorOnInputNonSquare(self): + ## inv = bijectors.MatrixInverseTriL(validate_args=True) + ## x_ = np.array([[1., 2., 3.], + ## [4., 5., 6.]], dtype=np.float32) + ## square_error_msg = "must be a square matrix" + ## with self.test_session(): + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse(x_).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + ## square_error_msg): + ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputNotLowerTriangular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 2.], + [3., 4.]], dtype=np.float32) + triangular_error_msg = "must be lower triangular" + with self.test_session(): + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + triangular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + @test_util.run_in_graph_and_eager_modes() + def testErrorOnInputSingular(self): + inv = bijectors.MatrixInverseTriL(validate_args=True) + x_ = np.array([[1., 0.], + [0., 0.]], dtype=np.float32) + nonsingular_error_msg = "must have all diagonal entries nonzero" + with self.test_session(): + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse(x_).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.forward_log_det_jacobian(x_, event_ndims=2).eval() + with self.assertRaisesOpError(nonsingular_error_msg): + inv.inverse_log_det_jacobian(x_, event_ndims=2).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py new file mode 100644 index 0000000000000000000000000000000000000000..566a7b3dff9b5d97a1cb143e0b32fc15984c3a02 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/scale_tril_test.py @@ -0,0 +1,69 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class ScaleTriLBijectorTest(test.TestCase): + """Tests the correctness of the ScaleTriL bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testComputesCorrectValues(self): + shift = 1.61803398875 + x = np.float32(np.array([-1, .5, 2])) + y = np.float32(np.array([[np.exp(2) + shift, 0.], + [.5, np.exp(-1) + shift]])) + + b = bijectors.ScaleTriL(diag_bijector=bijectors.Exp(), + diag_shift=shift) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + @test_util.run_in_graph_and_eager_modes() + def testInvertible(self): + + # Generate random inputs from an unconstrained space, with + # event size 6 to specify 3x3 triangular matrices. + batch_shape = [2, 1] + x = np.float32(np.random.randn(*(batch_shape + [6]))) + b = bijectors.ScaleTriL(diag_bijector=bijectors.Softplus(), + diag_shift=3.14159) + y = self.evaluate(b.forward(x)) + self.assertAllEqual(y.shape, batch_shape + [3, 3]) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=1)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllClose(fldj, -ildj) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 45760a29ee42835da69ef63803ccec7ce82a5a8f..795f1993ba5c31bf5a26333f31f1bc73125bff07 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -151,16 +151,24 @@ class SinhArcsinhBijectorTest(test.TestCase): self.assertAllClose(y, bijector.forward(x).eval(), rtol=1e-4, atol=0.) self.assertAllClose(x, bijector.inverse(y).eval(), rtol=1e-4, atol=0.) - # Do the numpy calculation in float128 to avoid inf/nan. - y_float128 = np.float128(y) - self.assertAllClose( - np.log(np.cosh( - np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( - y_float128**2 + 1)) - - np.log(tailweight), - bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), - rtol=1e-4, - atol=0.) + # On IBM PPC systems, longdouble (np.float128) is same as double except that it can have more precision. + # Type double being of 8 bytes, can't hold square of max of float64 (which is also 8 bytes) and + # below test fails due to overflow error giving inf. So this check avoids that error by skipping square + # calculation and corresponding assert. + + if np.amax(y) <= np.sqrt(np.finfo(np.float128).max) and \ + np.fabs(np.amin(y)) <= np.sqrt(np.fabs(np.finfo(np.float128).min)): + + # Do the numpy calculation in float128 to avoid inf/nan. + y_float128 = np.float128(y) + self.assertAllClose( + np.log(np.cosh( + np.arcsinh(y_float128) / tailweight - skewness) / np.sqrt( + y_float128**2 + 1)) - + np.log(tailweight), + bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), + rtol=1e-4, + atol=0.) self.assertAllClose( -bijector.inverse_log_det_jacobian(y, event_ndims=0).eval(), bijector.forward_log_det_jacobian(x, event_ndims=0).eval(), diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..6428a68702274fae384ae3de6d03f7ca126e2346 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/transform_diagonal_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class TransformDiagonalBijectorTest(test.TestCase): + """Tests correctness of the TransformDiagonal bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + @test_util.run_in_graph_and_eager_modes() + def testBijector(self): + x = np.float32(np.random.randn(3, 4, 4)) + + y = x.copy() + for i in range(x.shape[0]): + np.fill_diagonal(y[i, :, :], np.exp(np.diag(x[i, :, :]))) + + exp = bijectors.Exp() + b = bijectors.TransformDiagonal(diag_bijector=exp) + + y_ = self.evaluate(b.forward(x)) + self.assertAllClose(y, y_) + + x_ = self.evaluate(b.inverse(y)) + self.assertAllClose(x, x_) + + fldj = self.evaluate(b.forward_log_det_jacobian(x, event_ndims=2)) + ildj = self.evaluate(b.inverse_log_det_jacobian(y, event_ndims=2)) + self.assertAllEqual( + fldj, + self.evaluate(exp.forward_log_det_jacobian( + np.array([np.diag(x_mat) for x_mat in x]), + event_ndims=1))) + self.assertAllEqual( + ildj, + self.evaluate(exp.inverse_log_det_jacobian( + np.array([np.diag(y_mat) for y_mat in y]), + event_ndims=1))) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 31d24aa9ea09007b8db40e4869371b1f62639ac7..bbbec2103aefd3f38a9b734bcd3f2e15fc8bb683 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag @@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase): return False +class TestMoveDimension(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_static_shape(self): + + x = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1]) + + @test_util.run_in_graph_and_eager_modes() + def test_move_dimension_dynamic_shape(self): + + x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6]) + x = array_ops.placeholder_with_default(input=x_, shape=None) + + x_perm = distribution_util.move_dimension(x, 1, 1) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 4, 1, 6]) + + x_perm = distribution_util.move_dimension(x, 0, 3) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 0, -2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [30, 4, 1, 200, 6]) + + x_perm = distribution_util.move_dimension(x, 4, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + x_perm = distribution_util.move_dimension(x, -1, 2) + self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)), + [200, 30, 6, 4, 1]) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py index ce6cf702d522792f1ad26066a3d9be42003a0e3c..9c4dfed83631e9f0815fb674d650cac2e570b923 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py @@ -98,23 +98,21 @@ class StatisticalTestingTest(test.TestCase): num_samples = 5000 # 5000 samples is chosen to be enough to find discrepancies of # size 0.1 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = st.min_discrepancy_of_true_means_detectable_by_dkwm( - num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.1) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, 0., 1., false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.1) # Test that the confidence interval computed for the mean includes # 0.5 and excludes 0.4 and 0.6. - with self.test_session() as sess: - samples = rng.uniform(size=num_samples).astype(np.float32) - (low, high) = st.true_mean_confidence_interval_by_dkwm( - samples, 0., 1., error_rate=1e-6) - low, high = sess.run([low, high]) - self.assertGreater(low, 0.4) - self.assertLess(low, 0.5) - self.assertGreater(high, 0.5) - self.assertLess(high, 0.6) + samples = rng.uniform(size=num_samples).astype(np.float32) + (low, high) = st.true_mean_confidence_interval_by_dkwm( + samples, 0., 1., error_rate=1e-6) + low, high = self.evaluate([low, high]) + self.assertGreater(low, 0.4) + self.assertLess(low, 0.5) + self.assertGreater(high, 0.5) + self.assertLess(high, 0.6) def test_dkwm_mean_one_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -123,21 +121,45 @@ class StatisticalTestingTest(test.TestCase): # Test that the test assertion agrees that the mean of the standard # uniform distribution is 0.5. samples = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.5, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.4. - with self.assertRaisesOpError("Mean confidence interval too high"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.4, false_fail_rate=1e-6)) - - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is not 0.6. - with self.assertRaisesOpError("Mean confidence interval too low"): - sess.run(st.assert_true_mean_equal_by_dkwm( - samples, 0., 1., 0.6, false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.5, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not 0.6. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm( + samples, 0., 1., 0.6, false_fail_rate=1e-6)) + + def test_dkwm_mean_in_interval_one_sample_assertion(self): + rng = np.random.RandomState(seed=0) + num_samples = 5000 + + # Test that the test assertion agrees that the mean of the standard + # uniform distribution is between 0.4 and 0.6. + samples = rng.uniform(size=num_samples).astype(np.float32) + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.2 and 0.4. + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6)) + + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is not between 0.6 and 0.8. + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_in_interval_by_dkwm( + samples, 0., 1., + expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion(self): rng = np.random.RandomState(seed=0) @@ -145,20 +167,18 @@ class StatisticalTestingTest(test.TestCase): # 4000 samples is chosen to be enough to find discrepancies of # size 0.2 or more with assurance 1e-6, as confirmed here: - with self.test_session() as sess: - d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( - num_samples, 0., 1., num_samples, 0., 1., - false_fail_rate=1e-6, false_pass_rate=1e-6) - d = sess.run(d) - self.assertLess(d, 0.2) + d = st.min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( + num_samples, 0., 1., num_samples, 0., 1., + false_fail_rate=1e-6, false_pass_rate=1e-6) + d = self.evaluate(d) + self.assertLess(d, 0.2) # Test that the test assertion agrees that the standard # uniform distribution has the same mean as itself. samples1 = rng.uniform(size=num_samples).astype(np.float32) samples2 = rng.uniform(size=num_samples).astype(np.float32) - with self.test_session() as sess: - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., samples2, 0., 1., false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_2_1_false(self): rng = np.random.RandomState(seed=0) @@ -168,15 +188,14 @@ class StatisticalTestingTest(test.TestCase): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(2, 1). - beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples1 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_high_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(2, 1). + beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean smaller than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_high_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_mean_two_sample_assertion_beta_1_2_false(self): rng = np.random.RandomState(seed=0) @@ -186,15 +205,14 @@ class StatisticalTestingTest(test.TestCase): # As established above, 4000 samples is enough to find discrepancies # of size 0.2 or more with assurance 1e-6. - with self.test_session() as sess: - # Test that the test assertion confirms that the mean of the - # standard uniform distribution is different from the mean of beta(1, 2). - beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) - with self.assertRaisesOpError("samples2 has a smaller mean"): - sess.run(st.assert_true_mean_equal_by_dkwm_two_sample( - samples1, 0., 1., - beta_low_samples, 0., 1., - false_fail_rate=1e-6)) + # Test that the test assertion confirms that the mean of the + # standard uniform distribution is different from the mean of beta(1, 2). + beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32) + with self.assertRaisesOpError("true mean greater than expected"): + self.evaluate(st.assert_true_mean_equal_by_dkwm_two_sample( + samples1, 0., 1., + beta_low_samples, 0., 1., + false_fail_rate=1e-6)) def test_dkwm_argument_validity_checking(self): rng = np.random.RandomState(seed=0) @@ -203,18 +221,17 @@ class StatisticalTestingTest(test.TestCase): # Test that the test library complains if the given samples fall # outside the purported bounds. - with self.test_session() as sess: - with self.assertRaisesOpError("maximum value exceeds expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) - with self.assertRaisesOpError("minimum value falls below expectations"): - sess.run(st.true_mean_confidence_interval_by_dkwm( - samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) - - # But doesn't complain if they don't. - op = st.true_mean_confidence_interval_by_dkwm( - samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) - _ = sess.run(op) + with self.assertRaisesOpError("maximum value exceeds expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[0.5, 1.5]], error_rate=0.5)) + with self.assertRaisesOpError("minimum value falls below expectations"): + self.evaluate(st.true_mean_confidence_interval_by_dkwm( + samples, [[0.5, 1.5]], [[1., 2.]], error_rate=0.5)) + + # But doesn't complain if they don't. + op = st.true_mean_confidence_interval_by_dkwm( + samples, [[0., 1.]], [[1., 2.]], error_rate=0.5) + _ = self.evaluate(op) def test_do_maximum_mean(self): n = 117 @@ -223,10 +240,9 @@ class StatisticalTestingTest(test.TestCase): samples = rng.uniform(size=n).astype(np.float32) # Compute the answer in TF using the code under test - with self.test_session() as sess: - envelope_t = ops.convert_to_tensor(envelope) - max_mean = st._do_maximum_mean(samples, envelope_t, 1) - max_mean = sess.run(max_mean) + envelope_t = ops.convert_to_tensor(envelope) + max_mean = st._do_maximum_mean(samples, envelope_t, 1) + max_mean = self.evaluate(max_mean) # Compute the correct answer for this case in numpy. In this # example, `n` and `envelope` are such that `samples[2]` is the diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..03e26b198ea02ad1bef8bcd2f6076078ecd7df0b --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/BUILD @@ -0,0 +1,48 @@ +# Description: +# Internal testing utilities, e.g., computing the correct answer to +# put in a unit test. + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "correlation_matrix_volumes_py", + srcs = [ + "correlation_matrix_volumes_lib.py", + ], + deps = [ + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//third_party/py/numpy", + ], +) + +py_binary( + name = "correlation_matrix_volumes", + srcs = [ + "correlation_matrix_volumes.py", + ], + deps = [ + ":correlation_matrix_volumes_py", + ], +) + +py_test( + name = "correlation_matrix_volumes_test", + size = "medium", + srcs = ["correlation_matrix_volumes_test.py"], + tags = ["no_pip"], + deps = [ + ":correlation_matrix_volumes_py", + # For statistical testing + "//tensorflow/contrib/distributions:distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py new file mode 100644 index 0000000000000000000000000000000000000000..2eab51cd3053ea55f2e03619fd002fbf48251fb1 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Executable to estimate the volume of various sets of correlation matrices. + +See correlation_matrix_volumes_lib.py for purpose and methodology. + +Invocation example: +``` +python correlation_matrix_volumes.py --num_samples 1e7 +``` + +This will compute 10,000,000-sample confidence intervals for the +volumes of several sets of correlation matrices. Which sets, and the +desired statistical significance, are hard-coded in this source file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pprint + +from absl import app +from absl import flags + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr + +FLAGS = flags.FLAGS + +# Float to support giving the number of samples in scientific notation. +# The production run used for the LKJ test used 1e7 samples. +flags.DEFINE_float('num_samples', 1e4, 'Number of samples to use.') + + +def ctv_debatched(det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + # This wrapper undoes the batching in compute_true_volumes, because + # apparently several 5x5x9x1e7 Tensors of float32 can strain RAM. + bounds = {} + for db in det_bounds: + bounds[db] = corr.compute_true_volumes( + [db], dim, num_samples, error_rate=error_rate, seed=seed)[db] + return bounds + + +# The particular bounds in all three of these functions were chosen by +# a somewhat arbitrary walk through an empirical tradeoff, for the +# purpose of testing the LKJ distribution. Setting the determinant +# bound lower +# - Covers more of the testee's sample space, and +# - Increases the probability that the rejection sampler will hit, thus +# - Decreases the relative error (at a fixed sample count) in the +# rejection-based volume estimate; +# but also +# - Increases the variance of the estimator used in the LKJ test. +# This latter variance is also affected by the dimension and the +# tested concentration parameter, and can be compensated for with more +# compute (expensive) or a looser discrepancy limit (unsatisfying). +# The values here are the projection of the points in that test design +# space that ended up getting chosen. +def compute_3x3_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 3, num_samples, error_rate=5e-7, seed=46) + + +def compute_4x4_volumes(num_samples): + det_bounds = [0.01, 0.25, 0.3, 0.35, 0.4, 0.45] + return ctv_debatched( + det_bounds, 4, num_samples, error_rate=5e-7, seed=47) + + +def compute_5x5_volumes(num_samples): + det_bounds = [0.01, 0.2, 0.25, 0.3, 0.35, 0.4] + return ctv_debatched( + det_bounds, 5, num_samples, error_rate=5e-7, seed=48) + + +def main(_): + full_bounds = {} + full_bounds[3] = compute_3x3_volumes(int(FLAGS.num_samples)) + full_bounds[4] = compute_4x4_volumes(int(FLAGS.num_samples)) + full_bounds[5] = compute_5x5_volumes(int(FLAGS.num_samples)) + pprint.pprint(full_bounds) + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..455e71f00c96e799c4aaae25050c77a9ae36df06 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_lib.py @@ -0,0 +1,323 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Estimating the volume of the correlation matrices with bounded determinant. + +Why? Because lkj_test.py tests the sampler for the LKJ distribution +by estimating the same volume another way. + +How? Rejection sampling. Or, more precisely, importance sampling, +proposing from the uniform distribution on symmetric matrices with +diagonal 1s and entries in [-1, 1]. Such a matrix is a correlation +matrix if and only if it is also positive semi-definite. + +The samples can then be converted into a confidence interval on the +volume in question by the [Clopper-Pearson +method](https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval), +also implemented here. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import sys + +import numpy as np + +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import uniform +from tensorflow.python.ops.distributions import util +from tensorflow.python.platform import tf_logging + +__all__ = [ + "correlation_matrix_volume_rejection_samples", + "compute_true_volumes", +] + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +optimize = try_import("scipy.optimize") +stats = try_import("scipy.stats") + + +def _psd_mask(x): + """Computes whether each square matrix in the input is positive semi-definite. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + + Returns: + mask: A floating-point `Tensor` of shape `[B1, ... Bn]`. Each + scalar is 1 if the corresponding matrix was PSD, otherwise 0. + """ + # Allegedly + # https://scicomp.stackexchange.com/questions/12979/testing-if-a-matrix-is-positive-semi-definite + # it is more efficient to test for positive semi-definiteness by + # trying to compute the Cholesky decomposition -- the matrix is PSD + # if you succeed and not PSD if you fail. However, TensorFlow's + # Cholesky raises an exception if _any_ of the input matrices are + # not PSD, from which I don't know how to extract _which ones_, so I + # proceed by explicitly computing all the eigenvalues and checking + # whether they are all positive or not. + # + # Also, as was discussed in the answer, it is somewhat dangerous to + # treat SPD-ness as binary in floating-point arithmetic. Cholesky + # factorization can complete and 'look' like everything is fine + # (e.g., O(1) entries and a diagonal of all ones) but the matrix can + # have an exponential condition number. + eigenvalues, _ = linalg_ops.self_adjoint_eig(x) + return math_ops.cast( + math_ops.reduce_min(eigenvalues, axis=-1) >= 0, dtype=x.dtype) + + +def _det_large_enough_mask(x, det_bounds): + """Returns whether the input matches the given determinant limit. + + Args: + x: A floating-point `Tensor` of shape `[B1, ..., Bn, M, M]`. + det_bounds: A floating-point `Tensor` that must broadcast to shape + `[B1, ..., Bn]`, giving the desired lower bound on the + determinants in `x`. + + Returns: + mask: A floating-point `Tensor` of shape [B1, ..., Bn]. Each + scalar is 1 if the corresponding matrix had determinant above + the corresponding bound, otherwise 0. + """ + # For the curious: I wonder whether it is possible and desirable to + # use a Cholesky decomposition-based algorithm for this, since the + # only matrices whose determinant this code cares about will be PSD. + # Didn't figure out how to code that in TensorFlow. + # + # Expert opinion is that it would be about twice as fast since + # Cholesky is roughly half the cost of Gaussian Elimination with + # Partial Pivoting. But this is less of an impact than the switch in + # _psd_mask. + return math_ops.cast( + linalg_ops.matrix_determinant(x) > det_bounds, dtype=x.dtype) + + +def _uniform_correlation_like_matrix(num_rows, batch_shape, dtype, seed): + """Returns a uniformly random `Tensor` of "correlation-like" matrices. + + A "correlation-like" matrix is a symmetric square matrix with all entries + between -1 and 1 (inclusive) and 1s on the main diagonal. Of these, + the ones that are positive semi-definite are exactly the correlation + matrices. + + Args: + num_rows: Python `int` dimension of the correlation-like matrices. + batch_shape: `Tensor` or Python `tuple` of `int` shape of the + batch to return. + dtype: `dtype` of the `Tensor` to return. + seed: Random seed. + + Returns: + matrices: A `Tensor` of shape `batch_shape + [num_rows, num_rows]` + and dtype `dtype`. Each entry is in [-1, 1], and each matrix + along the bottom two dimensions is symmetric and has 1s on the + main diagonal. + """ + num_entries = num_rows * (num_rows + 1) / 2 + ones = array_ops.ones(shape=[num_entries], dtype=dtype) + # It seems wasteful to generate random values for the diagonal since + # I am going to throw them away, but `fill_triangular` fills the + # diagonal, so I probably need them. + # It's not impossible that it would be more efficient to just fill + # the whole matrix with random values instead of messing with + # `fill_triangular`. Then would need to filter almost half out with + # `matrix_band_part`. + unifs = uniform.Uniform(-ones, ones).sample(batch_shape, seed=seed) + tril = util.fill_triangular(unifs) + symmetric = tril + array_ops.matrix_transpose(tril) + diagonal_ones = array_ops.ones( + shape=util.pad(batch_shape, axis=0, back=True, value=num_rows), + dtype=dtype) + return array_ops.matrix_set_diag(symmetric, diagonal_ones) + + +def correlation_matrix_volume_rejection_samples( + det_bounds, dim, sample_shape, dtype, seed): + """Returns rejection samples from trying to get good correlation matrices. + + The proposal being rejected from is the uniform distribution on + "correlation-like" matrices. We say a matrix is "correlation-like" + if it is a symmetric square matrix with all entries between -1 and 1 + (inclusive) and 1s on the main diagonal. Of these, the ones that + are positive semi-definite are exactly the correlation matrices. + + The rejection algorithm, then, is to sample a `Tensor` of + `sample_shape` correlation-like matrices of dimensions `dim` by + `dim`, and check each one for (i) being a correlation matrix (i.e., + PSD), and (ii) having determinant at least the corresponding entry + of `det_bounds`. + + Args: + det_bounds: A `Tensor` of lower bounds on the determinants of + acceptable matrices. The shape must broadcast with `sample_shape`. + dim: A Python `int` dimension of correlation matrices to sample. + sample_shape: Python `tuple` of `int` shape of the samples to + compute, excluding the two matrix dimensions. + dtype: The `dtype` in which to do the computation. + seed: Random seed. + + Returns: + weights: A `Tensor` of shape `sample_shape`. Each entry is 0 if the + corresponding matrix was not a correlation matrix, or had too + small of a determinant. Otherwise, the entry is the + multiplicative inverse of the density of proposing that matrix + uniformly, i.e., the volume of the set of `dim` by `dim` + correlation-like matrices. + volume: The volume of the set of `dim` by `dim` correlation-like + matrices. + """ + with ops.name_scope("rejection_sampler"): + rej_proposals = _uniform_correlation_like_matrix( + dim, sample_shape, dtype, seed=seed) + rej_proposal_volume = 2. ** (dim * (dim - 1) / 2.) + # The density of proposing any given point is 1 / rej_proposal_volume; + # The weight of that point should be scaled by + # 1 / density = rej_proposal_volume. + rej_weights = rej_proposal_volume * _psd_mask( + rej_proposals) * _det_large_enough_mask(rej_proposals, det_bounds) + return rej_weights, rej_proposal_volume + + +def _clopper_pearson_confidence_interval(samples, error_rate): + """Computes a confidence interval for the mean of the given 1-D distribution. + + Assumes (and checks) that the given distribution is Bernoulli, i.e., + takes only two values. This licenses using the CDF of the binomial + distribution for the confidence, which is tighter (for extreme + probabilities) than the DKWM inequality. The method is known as the + [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Assumes: + + - The given samples were drawn iid from the distribution of interest. + + - The given distribution is a Bernoulli, i.e., supported only on + low and high. + + Guarantees: + + - The probability (over the randomness of drawing the given sample) + that the true mean is outside the returned interval is no more + than the given error_rate. + + Args: + samples: `np.ndarray` of samples drawn iid from the distribution + of interest. + error_rate: Python `float` admissible rate of mistakes. + + Returns: + low: Lower bound of confidence interval. + high: Upper bound of confidence interval. + + Raises: + ValueError: If `samples` has rank other than 1 (batch semantics + are not implemented), or if `samples` contains values other than + `low` or `high` (as that makes the distribution not Bernoulli). + """ + # TODO(b/78025336) Migrate this confidence interval function + # to statistical_testing.py. In order to do that + # - Get the binomial CDF from the Binomial distribution + # - Implement scalar root finding in TF. Batch bisection search + # shouldn't be too hard, and is definitely good enough for this + # problem. Batching the Brent algorithm (from scipy) that is used + # here may be more involved, but may also not be necessary---it's + # only used here because scipy made it convenient. In particular, + # robustness is more important than speed here, which may make + # bisection search actively better. + # - The rest is just a matter of rewriting in the appropriate style. + if optimize is None or stats is None: + raise ValueError( + "Scipy is required for computing Clopper-Pearson confidence intervals") + if len(samples.shape) != 1: + raise ValueError("Batch semantics not implemented") + n = len(samples) + low = np.amin(samples) + high = np.amax(samples) + successes = np.count_nonzero(samples - low) + failures = np.count_nonzero(samples - high) + if successes + failures != n: + uniques = np.unique(samples) + msg = ("Purportedly Bernoulli distribution had distinct samples" + " {}, {}, and {}".format(uniques[0], uniques[1], uniques[2])) + raise ValueError(msg) + def p_small_enough(p): + prob = stats.binom.logcdf(successes, n, p) + return prob - np.log(error_rate / 2.) + def p_big_enough(p): + prob = stats.binom.logsf(successes, n, p) + return prob - np.log(error_rate / 2.) + high_p = optimize.brentq( + p_small_enough, float(successes) / n, 1., rtol=1e-9) + low_p = optimize.brentq( + p_big_enough, 0., float(successes) / n, rtol=1e-9) + low_interval = low + (high - low) * low_p + high_interval = low + (high - low) * high_p + return (low_interval, high_interval) + + +def compute_true_volumes( + det_bounds, dim, num_samples, error_rate=1e-6, seed=42): + """Returns confidence intervals for the desired correlation matrix volumes. + + The confidence intervals are computed by the [Clopper-Pearson method] + (https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval). + + Args: + det_bounds: A rank-1 numpy array of lower bounds on the + determinants of acceptable matrices. Entries must be unique. + dim: A Python `int` dimension of correlation matrices to sample. + num_samples: The number of samples to draw. + error_rate: The statistical significance of the returned + confidence intervals. The significance is broadcast: Each + returned interval separately may be incorrect with probability + (under the sample of correlation-like matrices drawn internally) + at most `error_rate`. + seed: Random seed. + + Returns: + bounds: A Python `dict` mapping each determinant bound to the low, high + tuple giving the confidence interval. + """ + bounds = {} + with session.Session() as sess: + rej_weights, _ = correlation_matrix_volume_rejection_samples( + det_bounds, dim, [num_samples, len(det_bounds)], np.float32, seed=seed) + rej_weights = sess.run(rej_weights) + for rw, det in zip(np.rollaxis(rej_weights, 1), det_bounds): + template = ("Estimating volume of {}x{} correlation " + "matrices with determinant >= {}.") + print(template.format(dim, dim, det)) + sys.stdout.flush() + bounds[det] = _clopper_pearson_confidence_interval( + rw, error_rate=error_rate) + return bounds diff --git a/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f99300e63871119800a42f122c8321e5986541a --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/util/correlation_matrix_volumes_test.py @@ -0,0 +1,150 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for correlation_matrix_volumes_lib.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.kernel_tests.util import correlation_matrix_volumes_lib as corr +from tensorflow.contrib.distributions.python.ops import statistical_testing as st +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.platform import test + + +# NxN correlation matrices are determined by the N*(N-1)/2 +# lower-triangular entries. In addition to being between -1 and 1, +# they must also obey the constraint that the determinant of the +# resulting symmetric matrix is non-negative. In 2x2, we can even +# analytically compute the volume when the determinant is bounded to > +# epsilon, as that boils down to the one lower-triangular entry being +# less than 1 - epsilon in absolute value. +def two_by_two_volume(det_bound): + return 2 * np.sqrt(1.0 - det_bound) + + +# The post +# https://psychometroscar.com/the-volume-of-a-3-x-3-correlation-matrix/ +# derives (with elementary calculus) that the volume (with respect to +# Lebesgue^3 measure) of the set of 3x3 correlation matrices is +# pi^2/2. The same result is also obtained by [1]. +def three_by_three_volume(): + return np.pi**2 / 2. + + +# The volume of the unconstrained set of correlation matrices is also +# the normalization constant of the LKJ distribution from [2]. As +# part of defining the distribution, that reference a derives general +# formula for this volume for all dimensions. A TensorFlow +# computation thereof gave the below result for 4x4: +def four_by_four_volume(): + # This constant computed as math_ops.exp(lkj.log_norm_const(4, [1.0])) + return 11.6973076 + +# [1] Rousseeuw, P. J., & Molenberghs, G. (1994). "The shape of +# correlation matrices." The American Statistician, 48(4), 276-279. + +# [2] Daniel Lewandowski, Dorota Kurowicka, and Harry Joe, "Generating +# random correlation matrices based on vines and extended onion +# method," Journal of Multivariate Analysis 100 (2009), pp 1989-2001. + + +class CorrelationMatrixVolumesTest(test.TestCase): + + def testRejection2D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + exact_volumes = two_by_two_volume(det_bounds) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 2, [num_samples, 9], dtype=np.float32, seed=43) + # shape of rej_weights: [num_samples, 9, 2, 2] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + # Correct the false fail rate due to different broadcasting + false_fail_rate=1.1e-7, false_pass_rate=1e-6), + 0.036) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection3D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = np.array([three_by_three_volume()], dtype=np.float32) + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 3, [num_samples, 1], dtype=np.float32, seed=44) + # shape of rej_weights: [num_samples, 1, 3, 3] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 3% relative error + 0.15) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testRejection4D(self): + num_samples = int(1e5) # Chosen for a small min detectable discrepancy + det_bounds = np.array([0.0], dtype=np.float32) + exact_volumes = [four_by_four_volume()] + (rej_weights, + rej_proposal_volume) = corr.correlation_matrix_volume_rejection_samples( + det_bounds, 4, [num_samples, 1], dtype=np.float32, seed=45) + # shape of rej_weights: [num_samples, 1, 4, 4] + chk1 = st.assert_true_mean_equal_by_dkwm( + rej_weights, low=0., high=rej_proposal_volume, expected=exact_volumes, + false_fail_rate=1e-6) + chk2 = check_ops.assert_less( + st.min_discrepancy_of_true_means_detectable_by_dkwm( + num_samples, low=0., high=rej_proposal_volume, + false_fail_rate=1e-6, false_pass_rate=1e-6), + # Going for about a 10% relative error + 1.1) + with ops.control_dependencies([chk1, chk2]): + rej_weights = array_ops.identity(rej_weights) + self.evaluate(rej_weights) + + def testVolumeEstimation2D(self): + # Test that the confidence intervals produced by + # corr.compte_true_volumes are sound, in the sense of containing + # the exact volume. + num_samples = int(1e5) # Chosen by symmetry with testRejection2D + det_bounds = np.array( + [0.01, 0.02, 0.03, 0.04, 0.05, 0.3, 0.35, 0.4, 0.5], dtype=np.float32) + volume_bounds = corr.compute_true_volumes( + det_bounds, 2, num_samples, error_rate=1e-6, seed=47) + exact_volumes = two_by_two_volume(det_bounds) + for det, volume in zip(det_bounds, exact_volumes): + computed_low, computed_high = volume_bounds[det] + self.assertLess(computed_low, volume) + self.assertGreater(computed_high, volume) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py index d813831bef803a22c095d9c98e7163aa4861a15d..bb9b8043b2233b2109f51b5dde188d088fdb0d39 100644 --- a/tensorflow/contrib/distributions/python/ops/autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Autoregressive(distribution_lib.Distribution): @@ -107,6 +108,14 @@ class Autoregressive(distribution_lib.Distribution): https://arxiv.org/abs/1606.05328 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution_fn, sample0=None, @@ -144,7 +153,7 @@ class Autoregressive(distribution_lib.Distribution): `distribution_fn(sample0).event_shape.num_elements()` are both `None`. ValueError: if `num_steps < 1`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: self._distribution_fn = distribution_fn self._sample0 = sample0 diff --git a/tensorflow/contrib/distributions/python/ops/batch_reshape.py b/tensorflow/contrib/distributions/python/ops/batch_reshape.py index c709318f76552e1188f735f5bafff4be0537baed..519077bc9ab1063a1135486cfae34656f3f68157 100644 --- a/tensorflow/contrib/distributions/python/ops/batch_reshape.py +++ b/tensorflow/contrib/distributions/python/ops/batch_reshape.py @@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -72,6 +72,14 @@ class BatchReshape(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, batch_shape, @@ -103,7 +111,7 @@ class BatchReshape(distribution_lib.Distribution): ValueError: if `batch_shape` size is not the same as a `distribution.batch_shape` size. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) name = name or "BatchReshape" + distribution.name with ops.name_scope(name, values=[batch_shape]) as name: # The unexpanded batch shape may contain up to one dimension of -1. @@ -353,6 +361,14 @@ class BatchReshape(distribution_lib.Distribution): return runtime_assertions +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def calculate_reshape(original_shape, new_shape, validate=False, name=None): """Calculates the reshaped dimensions (replacing up to one -1 in reshape).""" batch_shape_static = tensor_util.constant_value_as_shape(new_shape) @@ -385,6 +401,14 @@ def calculate_reshape(original_shape, new_shape, validate=False, name=None): return expanded_new_shape, batch_shape_static, validations +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def validate_init_args_statically(distribution, batch_shape): """Helper to __init__ which makes or raises assertions.""" if batch_shape.shape.ndims is not None: diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 51478dbeffaabc58ce3662f25f06bc579e8a407e..e141f8b5c6423bd6cce4d09da6f49d55b3e25a24 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -24,23 +24,27 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@FillTriangular @@Gumbel @@Identity @@Inline @@Invert @@Kumaraswamy @@MaskedAutoregressiveFlow +@@MatrixInverseTriL @@Ordered @@Permute @@PowerTransform @@RealNVP @@Reshape +@@ScaleTriL @@Sigmoid @@SinhArcsinh @@SoftmaxCentered @@Softplus @@Softsign @@Square +@@TransformDiagonal @@Weibull @@masked_autoregressive_default_template @@ -63,22 +67,26 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.fill_triangular import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.matrix_inverse_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.ordered import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * +from tensorflow.contrib.distributions.python.ops.bijectors.scale_tril import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.softplus import * from tensorflow.contrib.distributions.python.ops.bijectors.softsign import * from tensorflow.contrib.distributions.python.ops.bijectors.square import * +from tensorflow.contrib.distributions.python.ops.bijectors.transform_diagonal import * from tensorflow.python.ops.distributions.bijector import * from tensorflow.python.ops.distributions.identity_bijector import Identity diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index c9e31d7712f09f6c4b4cc6ae51a34c42a19c291d..4d6a46e7358933fdf512f49eae2673f35953c90a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -23,6 +23,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "AbsoluteValue", @@ -70,6 +71,14 @@ class AbsoluteValue(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="absolute_value"): """Instantiates the `AbsoluteValue` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index b4c2939eb914d50475ba6b1c1e979a804090f641..25f29452c3949600b8a4153a8585dd7269bd3b2b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,6 +37,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _as_tensor(x, name): """Convenience to convert to `Tensor` or leave as `None`.""" return None if x is None else ops.convert_to_tensor(x, name=name) @@ -97,6 +106,14 @@ class Affine(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale_identity_multiplier=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index 59f9742d576a7804f401d3a47ba31ae61d6c6e54..91301f15ad87e133777371b346864ecf7b964f27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.util import deprecation __all__ = [ @@ -88,6 +89,14 @@ class AffineLinearOperator(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py index cd792e2c8cf48602daf9fb5eb56b8c34bac050c7..460d906231bd30f8cec4fe21d42afe7b2a05805e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_scalar.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -52,6 +53,14 @@ class AffineScalar(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift=None, scale=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py index 224cec8a63dba53a528490117efac890312fe8d5..f19f147dd645b4f805f1905899b44293284d4225 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/batch_normalization.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -34,6 +35,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _undo_batch_normalization(x, mean, variance, @@ -128,6 +137,14 @@ class BatchNormalization(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batchnorm_layer=None, training=True, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index b158a51bb022b5e2ea3afda74e97b9dc131665a6..910774ea5bb4106a948567144c46c6db23a2c6e0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,10 +32,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _use_static_shape(input_tensor, ndims): return input_tensor.shape.is_fully_defined() and isinstance(ndims, int) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _compute_min_event_ndims(bijector_list, compute_forward=True): """Computes the min_event_ndims associated with the give list of bijectors. @@ -142,6 +159,14 @@ class Chain(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijectors=None, validate_args=False, name=None): """Instantiates `Chain` bijector. @@ -234,7 +259,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return ildj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.inverse_min_event_ndims) if _use_static_shape(y, event_ndims): @@ -248,12 +273,15 @@ class Chain(bijector.Bijector): if _use_static_shape(y, event_ndims): event_shape = b.inverse_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( event_shape.ndims) else: event_shape = b.inverse_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ + y = b.inverse(y, **kwargs.get(b.name, {})) return ildj @@ -270,7 +298,7 @@ class Chain(bijector.Bijector): if not self.bijectors: return fldj - event_ndims = self._maybe_get_event_ndims_statically( + event_ndims = self._maybe_get_static_event_ndims( self.forward_min_event_ndims) if _use_static_shape(x, event_ndims): @@ -283,21 +311,14 @@ class Chain(bijector.Bijector): x, event_ndims=event_ndims, **kwargs.get(b.name, {})) if _use_static_shape(x, event_ndims): event_shape = b.forward_event_shape(event_shape) - event_ndims = self._maybe_get_event_ndims_statically(event_shape.ndims) + event_ndims = self._maybe_get_static_event_ndims(event_shape.ndims) else: event_shape = b.forward_event_shape_tensor(event_shape) - event_ndims = self._maybe_get_event_ndims_statically( - array_ops.size(event_shape)) + event_ndims = array_ops.size(event_shape) + event_ndims_ = self._maybe_get_static_event_ndims(event_ndims) + if event_ndims_ is not None: + event_ndims = event_ndims_ x = b.forward(x, **kwargs.get(b.name, {})) return fldj - - def _maybe_get_event_ndims_statically(self, event_ndims): - event_ndims_ = super(Chain, self)._maybe_get_event_ndims_statically( - event_ndims) - if event_ndims_ is None: - return event_ndims - return event_ndims_ - - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 268c8d03426d435dc38412ac1bd05c674bd05d2b..8267ee7df89f69f8d610e9507e0cca9f4a5d4323 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -69,6 +70,14 @@ class CholeskyOuterProduct(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="cholesky_outer_product"): """Instantiates the `CholeskyOuterProduct` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 9fc1bbf052b419d07a9db149b990c2b80190d72b..07627e1e45eae6b63d830b2adf036bdc3b1d2895 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.bijectors import power_transform +from tensorflow.python.util import deprecation __all__ = [ @@ -47,6 +48,14 @@ class Exp(power_transform.PowerTransform): over the event space. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="exp"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py new file mode 100644 index 0000000000000000000000000000000000000000..31a9ca27e519bc312813668bf621a875838f12a0 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/fill_triangular.py @@ -0,0 +1,165 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FillTriangular bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as dist_util +from tensorflow.python.util import deprecation + + +__all__ = [ + "FillTriangular", +] + + +class FillTriangular(bijector.Bijector): + """Transforms vectors to triangular. + + Triangular matrix elements are filled in a clockwise spiral. + + Given input with shape `batch_shape + [d]`, produces output with + shape `batch_shape + [n, n]`, where + `n = (-1 + sqrt(1 + 8 * d))/2`. + This follows by solving the quadratic equation + `d = 1 + 2 + ... + n = n * (n + 1)/2`. + + #### Example + + ```python + b = tfb.FillTriangular(upper=False) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[4, 0, 0], + # [6, 5, 0], + # [3, 2, 1]] + + b = tfb.FillTriangular(upper=True) + b.forward([1, 2, 3, 4, 5, 6]) + # ==> [[1, 2, 3], + # [0, 5, 6], + # [0, 0, 4]] + + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + upper=False, + validate_args=False, + name="fill_triangular"): + """Instantiates the `FillTriangular` bijector. + + Args: + upper: Python `bool` representing whether output matrix should be upper + triangular (`True`) or lower triangular (`False`, default). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._upper = upper + super(FillTriangular, self).__init__( + forward_min_event_ndims=1, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return dist_util.fill_triangular(x, upper=self._upper) + + def _inverse(self, y): + return dist_util.fill_triangular_inverse(y, upper=self._upper) + + def _forward_log_det_jacobian(self, x): + return array_ops.zeros_like(x[..., 0]) + + def _inverse_log_det_jacobian(self, y): + return array_ops.zeros_like(y[..., 0, 0]) + + def _forward_event_shape(self, input_shape): + batch_shape, d = input_shape[:-1], input_shape[-1].value + if d is None: + n = None + else: + n = vector_size_to_square_matrix_size(d, self.validate_args) + return batch_shape.concatenate([n, n]) + + def _inverse_event_shape(self, output_shape): + batch_shape, n1, n2 = (output_shape[:-2], + output_shape[-2].value, + output_shape[-1].value) + if n1 is None or n2 is None: + m = None + elif n1 != n2: + raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2)) + else: + m = n1 * (n1 + 1) / 2 + return batch_shape.concatenate([m]) + + def _forward_event_shape_tensor(self, input_shape_tensor): + batch_shape, d = input_shape_tensor[:-1], input_shape_tensor[-1] + n = vector_size_to_square_matrix_size(d, self.validate_args) + return array_ops.concat([batch_shape, [n, n]], axis=0) + + def _inverse_event_shape_tensor(self, output_shape_tensor): + batch_shape, n = output_shape_tensor[:-2], output_shape_tensor[-1] + if self.validate_args: + is_square_matrix = check_ops.assert_equal( + n, output_shape_tensor[-2], message="Matrix must be square.") + with ops.control_dependencies([is_square_matrix]): + n = array_ops.identity(n) + d = math_ops.cast(n * (n + 1) / 2, output_shape_tensor.dtype) + return array_ops.concat([batch_shape, [d]], axis=0) + + +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) +def vector_size_to_square_matrix_size(d, validate_args, name=None): + """Convert a vector size to a matrix size.""" + if isinstance(d, (float, int, np.generic, np.ndarray)): + n = (-1 + np.sqrt(1 + 8 * d)) / 2. + if float(int(n)) != n: + raise ValueError("Vector length is not a triangular number.") + return int(n) + else: + with ops.name_scope(name, "vector_size_to_square_matrix_size", [d]) as name: + n = (-1. + math_ops.sqrt(1 + 8. * math_ops.to_float(d))) / 2. + if validate_args: + with ops.control_dependencies([check_ops.assert_equal( + math_ops.to_float(math_ops.to_int32(n)), n, + message="Vector length is not a triangular number")]): + n = array_ops.identity(n) + return math_ops.cast(n, d.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index e656a258e56e71898ecb719dd2af876f158cf799..71e562a927a30a17d695b81c566f981db7553ad9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Gumbel", @@ -45,6 +46,14 @@ class Gumbel(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=0., scale=1., diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index 2bde956d1345129285acae4684256c5ac828b9a1..1504bd27204f728c0cb519159230e945128c4740 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -43,6 +44,14 @@ class Inline(bijector.Bijector): The above example is equivalent to the `Bijector` `Exp()`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, forward_fn=None, inverse_fn=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index 84a3289ba2160ed22a2bc7030dd612ba9ca6f6df..a648676d4b1956e5c27f67a71e6bd93d0d7fc97d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Invert", @@ -40,6 +41,14 @@ class Invert(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, bijector, validate_args=False, name=None): """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py index 97000c17262d3efdef10274711364c2bc2083bd4..33b75a04d34fdd01bc0d854d4e5b9c45a737b122 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -44,6 +45,14 @@ class Kumaraswamy(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 83667b0e80cfcc1c4f0617cdc739221f24439665..b8f2a4b2c731bdaee78692c036fb9f2fba4e3760 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops import variable_scope as variable_scope_lib from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -186,6 +187,14 @@ class MaskedAutoregressiveFlow(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, shift_and_log_scale_fn, is_constant_jacobian=False, @@ -296,6 +305,14 @@ MASK_INCLUSIVE = "inclusive" MASK_EXCLUSIVE = "exclusive" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): """Generate the slices for building an autoregressive mask.""" # TODO(b/67594795): Better support of dynamic shape. @@ -313,6 +330,14 @@ def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): return slices +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _gen_mask(num_blocks, n_in, n_out, @@ -327,6 +352,14 @@ def _gen_mask(num_blocks, return mask +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_dense(inputs, units, num_blocks=None, @@ -399,6 +432,14 @@ def masked_dense(inputs, return layer.apply(inputs) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def masked_autoregressive_default_template( hidden_layers, shift_only=False, @@ -515,6 +556,14 @@ def masked_autoregressive_default_template( "masked_autoregressive_default_template", _fn) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): """Clips input while leaving gradient unaltered.""" with ops.name_scope(name, "clip_by_value_preserve_grad", diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..49e6192f067edec4890dcfa107876a5104c14dd4 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/matrix_inverse_tril.py @@ -0,0 +1,154 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MatrixInverseTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + + +__all__ = [ + "MatrixInverseTriL", +] + + +class MatrixInverseTriL(bijector.Bijector): + """Computes `g(L) = inv(L)`, where `L` is a lower-triangular matrix. + + `L` must be nonsingular; equivalently, all diagonal entries of `L` must be + nonzero. + + The input must have `rank >= 2`. The input is treated as a batch of matrices + with batch shape `input.shape[:-2]`, where each matrix has dimensions + `input.shape[-2]` by `input.shape[-1]` (hence `input.shape[-2]` must equal + `input.shape[-1]`). + + #### Examples + + ```python + tfd.bijectors.MatrixInverseTriL().forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 0], [-2, 1]], i.e., inv(x) + + tfd.bijectors.MatrixInverseTriL().inverse(y=[[1., 0], [-2, 1]]) + # Result: [[1., 0], [2, 1]], i.e., inv(y). + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, validate_args=False, name="matrix_inverse_tril"): + """Instantiates the `MatrixInverseTriL` bijector. + + Args: + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + super(MatrixInverseTriL, self).__init__( + forward_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + with ops.control_dependencies(self._assertions(x)): + shape = array_ops.shape(x) + return linalg_ops.matrix_triangular_solve( + x, linalg_ops.eye(shape[-1], batch_shape=shape[:-2]), lower=True) + + def _inverse(self, y): + return self._forward(y) + + def _forward_log_det_jacobian(self, x): + # Calculation of the Jacobian: + # + # Let X = (x_{ij}), 0 <= i,j < n, be a matrix of indeterminates. Let Z = + # X^{-1} where Z = (z_{ij}). Then + # + # dZ/dx_{ij} = (d/dt | t=0) Y(t)^{-1}, + # + # where Y(t) = X + t*E_{ij} and E_{ij} is the matrix with a 1 in the (i,j) + # entry and zeros elsewhere. By the product rule, + # + # 0 = d/dt [Identity matrix] + # = d/dt [Y Y^{-1}] + # = Y d/dt[Y^{-1}] + dY/dt Y^{-1} + # + # so + # + # d/dt[Y^{-1}] = -Y^{-1} dY/dt Y^{-1} + # = -Y^{-1} E_{ij} Y^{-1}. + # + # Evaluating at t=0, + # + # dZ/dx_{ij} = -Z E_{ij} Z. + # + # Taking the (r,s) entry of each side, + # + # dz_{rs}/dx_{ij} = -z_{ri}z_{sj}. + # + # Now, let J be the Jacobian dZ/dX, arranged as the n^2-by-n^2 matrix whose + # (r*n + s, i*n + j) entry is dz_{rs}/dx_{ij}. Considering J as an n-by-n + # block matrix with n-by-n blocks, the above expression for dz_{rs}/dx_{ij} + # shows that the block at position (r,i) is -z_{ri}Z. Hence + # + # J = -KroneckerProduct(Z, Z), + # det(J) = (-1)^(n^2) (det Z)^(2n) + # = (-1)^n (det X)^(-2n). + with ops.control_dependencies(self._assertions(x)): + return (-2. * math_ops.cast(array_ops.shape(x)[-1], x.dtype.base_dtype) * + math_ops.reduce_sum( + math_ops.log(math_ops.abs(array_ops.matrix_diag_part(x))), + axis=-1)) + + def _assertions(self, x): + if not self.validate_args: + return [] + shape = array_ops.shape(x) + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must have rank at least 2.") + is_square = check_ops.assert_equal( + shape[-2], shape[-1], message="Input must be a square matrix.") + above_diagonal = array_ops.matrix_band_part( + array_ops.matrix_set_diag( + x, array_ops.zeros(shape[:-1], dtype=dtypes.float32)), + 0, -1) + is_lower_triangular = check_ops.assert_equal( + above_diagonal, array_ops.zeros_like(above_diagonal), + message="Input must be lower triangular.") + # A lower triangular matrix is nonsingular iff all its diagonal entries are + # nonzero. + diag_part = array_ops.matrix_diag_part(x) + is_nonsingular = check_ops.assert_none_equal( + diag_part, array_ops.zeros_like(diag_part), + message="Input must have all diagonal entries nonzero.") + return [is_matrix, is_square, is_lower_triangular, is_nonsingular] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py index 3f03592f314cc13e8a9ea7e2ae18c5bb1f14e74f..fb393218b6b47764f45b5055bbf15cc17aba219e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/ordered.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -57,6 +58,14 @@ class Ordered(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="ordered"): super(Ordered, self).__init__( forward_min_event_ndims=1, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index 12a16a3f2ba3da53077307fd97d3f10d99b2c81f..f182a1adcbb6b11af2376cd271f903d50e50f1a0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -74,6 +75,14 @@ class Permute(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, permutation, validate_args=False, name=None): """Creates the `Permute` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index 71f123f2a998458edaa9c8da07ea2932f62625ca..16264fe728a334db347304500767ce5876f9db7e 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -41,6 +42,14 @@ class PowerTransform(bijector.Bijector): This bijector is equivalent to the `Exp` bijector when `c=0`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, power=0., validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py index 66e8a5b9b356867424d1d47efaf848fc6903c371..773ae2446118051a61636bc21de6b81dfacda746 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import template as template_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -126,6 +127,14 @@ class RealNVP(bijector.Bijector): Processing Systems_, 2017. https://arxiv.org/abs/1705.07057 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, num_masked, shift_and_log_scale_fn, @@ -228,6 +237,14 @@ class RealNVP(bijector.Bijector): return math_ops.reduce_sum(log_scale, axis=-1) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def real_nvp_default_template( hidden_layers, shift_only=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 5497c422e4d51e259435692dac722f801e8844ac..c8282229a30fabff0c4c267d0bdfcdbce4f5f3d9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -36,10 +37,26 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _static_ndims_from_shape(shape): return shape.shape.with_rank_at_least(1)[0].value +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _ndims_from_shape(shape): return array_ops.shape(shape)[0] @@ -86,6 +103,14 @@ class Reshape(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, event_shape_out, event_shape_in=(-1,), validate_args=False, name=None): """Creates a `Reshape` bijector. diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..6fbe8665781211ca803feb8bf5a8c04fb0b969e8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/scale_tril.py @@ -0,0 +1,123 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ScaleTriL bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops.bijectors import affine_scalar +from tensorflow.contrib.distributions.python.ops.bijectors import chain +from tensorflow.contrib.distributions.python.ops.bijectors import fill_triangular +from tensorflow.contrib.distributions.python.ops.bijectors import softplus +from tensorflow.contrib.distributions.python.ops.bijectors import transform_diagonal +from tensorflow.python.util import deprecation + +__all__ = [ + "ScaleTriL", +] + + +class ScaleTriL(chain.Chain): + """Transforms unconstrained vectors to TriL matrices with positive diagonal. + + This is implemented as a simple `tfb.Chain` of `tfb.FillTriangular` + followed by `tfb.TransformDiagonal`, and provided mostly as a + convenience. The default setup is somewhat opinionated, using a + Softplus transformation followed by a small shift (`1e-5`) which + attempts to avoid numerical issues from zeros on the diagonal. + + #### Examples + + ```python + tfb = tf.contrib.distributions.bijectors + b = tfb.ScaleTriL( + diag_bijector=tfb.Exp(), + diag_shift=None) + b.forward(x=[0., 0., 0.]) + # Result: [[1., 0.], + # [0., 1.]] + b.inverse(y=[[1., 0], + [.5, 2]]) + # Result: [log(2), .5, log(1)] + + # Define a distribution over PSD matrices of shape `[3, 3]`, + # with `1 + 2 + 3 = 6` degrees of freedom. + dist = tfd.TransformedDistribution( + tfd.Normal(tf.zeros(6), tf.ones(6)), + tfb.Chain([tfb.CholeskyOuterProduct(), tfb.ScaleTriL()])) + + # Using an identity transformation, ScaleTriL is equivalent to + # tfb.FillTriangular. + b = tfb.ScaleTriL( + diag_bijector=tfb.Identity(), + diag_shift=None) + + # For greater control over initialization, one can manually encode + # pre- and post- shifts inside of `diag_bijector`. + b = tfb.ScaleTriL( + diag_bijector=tfb.Chain([ + tfb.AffineScalar(shift=1e-3), + tfb.Softplus(), + tfb.AffineScalar(shift=0.5413)]), # softplus_inverse(1.) + # = log(expm1(1.)) = 0.5413 + diag_shift=None) + ``` + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector=None, + diag_shift=1e-5, + validate_args=False, + name="scale_tril"): + """Instantiates the `ScaleTriL` bijector. + + Args: + diag_bijector: `Bijector` instance, used to transform the output diagonal + to be positive. + Default value: `None` (i.e., `tfb.Softplus()`). + diag_shift: Float value broadcastable and added to all diagonal entries + after applying the `diag_bijector`. Setting a positive + value forces the output diagonal entries to be positive, but + prevents inverting the transformation for matrices with + diagonal entries less than this value. + Default value: `1e-5` (i.e., no shift is applied). + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + Default value: `False` (i.e., arguments are not validated). + name: Python `str` name given to ops managed by this object. + Default value: `scale_tril`. + """ + + if diag_bijector is None: + diag_bijector = softplus.Softplus(validate_args=validate_args) + + if diag_shift is not None: + diag_bijector = chain.Chain([affine_scalar.AffineScalar(shift=diag_shift), + diag_bijector]) + + super(ScaleTriL, self).__init__( + [transform_diagonal.TransformDiagonal(diag_bijector=diag_bijector), + fill_triangular.FillTriangular()], + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index 5df8c886315ff75cdc884e3b9b4665fb64bb109d..194b318fce31a13f84e7b664b58cebb24fc9a264 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -31,6 +32,14 @@ __all__ = [ class Sigmoid(bijector.Bijector): """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="sigmoid"): super(Sigmoid, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index 2a32e8abcde940b0056b0faf2955ec1b3bd71803..241fba2cb7ec33b7b02c1ca79051f1b826d7d2aa 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -26,12 +26,21 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _sqrtx2p1(x): """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" return array_ops.where( @@ -88,6 +97,14 @@ class SinhArcsinh(bijector.Bijector): `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, skewness=None, tailweight=None, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index f52b91550edff7390d8094a4508d862674e85d59..20ee0d340833d5c5275e2ab52a89dcdf7198add1 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -60,6 +61,14 @@ class SoftmaxCentered(bijector.Bijector): makes the (forward) image non-open and the theorem does not directly apply. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softmax_centered"): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 96a938c803418ff818f9c531754b47ba1eb8667a..3df84ef8b04c2c8f6be91ecd1c972ad1484b4285 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import bijector from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -80,6 +81,14 @@ class Softplus(bijector.Bijector): "hinge_softness": ( "Nonzero floating point `Tensor`. Controls the softness of what " "would otherwise be a kink at the origin. Default is 1.0")}) + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, hinge_softness=None, validate_args=False, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py index b4a658c171b8313358754228aabbfa4bf93fd84d..f96a4bb01de59a21107b9e7c14f929e13e358ac9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softsign.py @@ -22,6 +22,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -51,6 +52,14 @@ class Softsign(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="softsign"): super(Softsign, self).__init__( forward_min_event_ndims=0, diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/square.py b/tensorflow/contrib/distributions/python/ops/bijectors/square.py index 2ccfdc95970e387e708603e2614ad29fb6a18db3..294460a80f6209797831ea361e64efe677f71e59 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/square.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/square.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ class Square(bijector.Bijector): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, validate_args=False, name="square"): """Instantiates the `Square` bijector. @@ -81,4 +90,3 @@ class Square(bijector.Bijector): is_valid = check_ops.assert_non_negative( t, message="All elements must be non-negative.") return control_flow_ops.with_dependencies([is_valid], t) - diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..9b7a3b026b8dcc31bed49c489d77b9c184f463cb --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/transform_diagonal.py @@ -0,0 +1,111 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TransformDiagonal bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation + +__all__ = [ + "TransformDiagonal", +] + + +class TransformDiagonal(bijector.Bijector): + """Applies a Bijector to the diagonal of a matrix. + + #### Example + + ```python + b = tfb.TransformDiagonal(diag_bijector=tfb.Exp()) + + b.forward([[1., 0.], + [0., 1.]]) + # ==> [[2.718, 0.], + [0., 2.718]] + ``` + + """ + + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) + def __init__(self, + diag_bijector, + validate_args=False, + name="transform_diagonal"): + """Instantiates the `TransformDiagonal` bijector. + + Args: + diag_bijector: `Bijector` instance used to transform the diagonal. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._diag_bijector = diag_bijector + super(TransformDiagonal, self).__init__( + forward_min_event_ndims=2, + inverse_min_event_ndims=2, + validate_args=validate_args, + name=name) + + def _forward(self, x): + diag = self._diag_bijector.forward(array_ops.matrix_diag_part(x)) + return array_ops.matrix_set_diag(x, diag) + + def _inverse(self, y): + diag = self._diag_bijector.inverse(array_ops.matrix_diag_part(y)) + return array_ops.matrix_set_diag(y, diag) + + def _forward_log_det_jacobian(self, x): + # We formulate the Jacobian with respect to the flattened matrices + # `vec(x)` and `vec(y)`. Suppose for notational convenience that + # the first `n` entries of `vec(x)` are the diagonal of `x`, and + # the remaining `n**2-n` entries are the off-diagonals in + # arbitrary order. Then the Jacobian is a block-diagonal matrix, + # with the Jacobian of the diagonal bijector in the first block, + # and the identity Jacobian for the remaining entries (since this + # bijector acts as the identity on non-diagonal entries): + # + # J_vec(x) (vec(y)) = + # ------------------------------- + # | J_diag(x) (diag(y)) 0 | n entries + # | | + # | 0 I | n**2-n entries + # ------------------------------- + # n n**2-n + # + # Since the log-det of the second (identity) block is zero, the + # overall log-det-jacobian is just the log-det of first block, + # from the diagonal bijector. + # + # Note that for elementwise operations (exp, softplus, etc) the + # first block of the Jacobian will itself be a diagonal matrix, + # but our implementation does not require this to be true. + return self._diag_bijector.forward_log_det_jacobian( + array_ops.matrix_diag_part(x), event_ndims=1) + + def _inverse_log_det_jacobian(self, y): + return self._diag_bijector.inverse_log_det_jacobian( + array_ops.matrix_diag_part(y), event_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index a22560fe80298b762795e7b0e7aea2db55823065..8903a70d98ae144731b12047e5074d0450b59378 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -24,6 +24,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.util import deprecation __all__ = [ @@ -47,6 +48,14 @@ class Weibull(bijector.Bijector): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale=1., concentration=1., diff --git a/tensorflow/contrib/distributions/python/ops/binomial.py b/tensorflow/contrib/distributions/python/ops/binomial.py index 24b26bf124c78c8320b9a6bc3b900e6c7a93f5e4..b349e5966dd750fdf96c0b211dce02658c9400b7 100644 --- a/tensorflow/contrib/distributions/python/ops/binomial.py +++ b/tensorflow/contrib/distributions/python/ops/binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation _binomial_sample_note = """ @@ -42,6 +43,14 @@ to integer values. """ +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _bdtr(k, n, p): """The binomial cumulative distribution function. @@ -130,6 +139,14 @@ class Binomial(distribution.Distribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, @@ -163,7 +180,7 @@ class Binomial(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._total_count = self._maybe_assert_valid_total_count( ops.convert_to_tensor(total_count, name="total_count"), diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index f5ffdd873124d6626dca26f603592bd0b030d7b3..cb5223b0557080e10bf24c3e1cb432f15fd5e7e3 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Cauchy", @@ -93,6 +93,14 @@ class Cauchy(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -121,7 +129,7 @@ class Cauchy(distribution.Distribution): Raises: TypeError: if `loc` and `scale` have different `dtype`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index 08cdc1582892cc7d308bd60f082dde082704f57f..e9a7b39070f3d76693ad54852ed0847a0980d2a6 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -25,7 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import gamma -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -64,6 +64,14 @@ class Chi2(gamma.Gamma): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, @@ -84,7 +92,7 @@ class Chi2(gamma.Gamma): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) # Even though all stats of chi2 are defined for valid parameters, this is # not true in the parent class "gamma." therefore, passing # allow_nan_stats=True @@ -115,12 +123,20 @@ class Chi2(gamma.Gamma): class Chi2WithAbsDf(Chi2): """Chi2 with parameter transform `df = floor(abs(df))`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, validate_args=False, allow_nan_stats=True, name="Chi2WithAbsDf"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[df]) as name: super(Chi2WithAbsDf, self).__init__( df=math_ops.floor( diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 10b45361358b40a3c8fd725f27ad84ef9b8a37f5..3598c8d23ea9007fb359ae4931738fb61ede4ccc 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -20,7 +20,6 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import conditional_distribution from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import transformed_distribution @@ -106,7 +105,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -131,7 +130,7 @@ class ConditionalTransformedDistribution( bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) - event_ndims = self._maybe_get_event_ndims_statically() + event_ndims = self._maybe_get_static_event_ndims() ildj = self.bijector.inverse_log_det_jacobian( y, event_ndims=event_ndims, **bijector_kwargs) if self.bijector._is_injective: # pylint: disable=protected-access @@ -220,14 +219,14 @@ class ConditionalTransformedDistribution( inv_cdf = self.distribution.quantile(value, **distribution_kwargs) return self.bijector.forward(inv_cdf, **bijector_kwargs) - def _maybe_get_event_ndims_statically(self): + def _maybe_get_static_event_ndims(self): if self.event_shape.ndims is not None: return self.event_shape.ndims event_ndims = array_ops.size(self.event_shape_tensor()) - static_event_ndims = tensor_util.constant_value(event_ndims) + event_ndims_ = distribution_util.maybe_get_static_value(event_ndims) - if static_event_ndims is not None: - return static_event_ndims + if event_ndims_ is not None: + return event_ndims_ return event_ndims diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 6d7d6d307bd0f815344c8a0e347f45ae11ba6462..ad853ee293f86565c1af601214522f53d936b70a 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -32,7 +32,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Deterministic", @@ -44,6 +44,14 @@ __all__ = [ class _BaseDeterministic(distribution.Distribution): """Base class for Deterministic distributions.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -87,7 +95,7 @@ class _BaseDeterministic(distribution.Distribution): Raises: ValueError: If `loc` is a scalar. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, atol, rtol]) as name: loc = ops.convert_to_tensor(loc, name="loc") if is_vector and validate_args: @@ -204,6 +212,14 @@ class Deterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, @@ -309,6 +325,14 @@ class VectorDeterministic(_BaseDeterministic): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, atol=None, diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 289e1d50e1146a641c0cc433ece3465aed73b1c2..6959b3e8775d2dd488b4ee3252d143ef376d58f9 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -21,12 +21,19 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib + +# The following two lines are redundant, in a sense. The first enables +# good coding practice *within* this file (`util.prefer_static_value` +# rather than `prefer_static_value`). The second ensures that users +# also get the core utils when they import this file. +from tensorflow.python.ops.distributions import util from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import @@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) + + +def move_dimension(x, source_idx, dest_idx): + """Move a single tensor dimension within its shape. + + This is a special case of `tf.transpose()`, which applies + arbitrary permutations to tensor dimensions. + + Args: + x: Tensor of rank `ndims`. + source_idx: Integer index into `x.shape` (negative indexing is + supported). + dest_idx: Integer index into `x.shape` (negative indexing is + supported). + + Returns: + x_perm: Tensor of rank `ndims`, in which the dimension at original + index `source_idx` has been moved to new index `dest_idx`, with + all other dimensions retained in their original order. + + Example: + + ```python + x = tf.placeholder(shape=[200, 30, 4, 1, 6]) + x_perm = _move_dimension(x, 1, 1) # no-op + x_perm = _move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6] + x_perm = _move_dimension(x, 0, -2) # equivalent to previous + x_perm = _move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1] + ``` + """ + ndims = util.prefer_static_rank(x) + if isinstance(source_idx, int): + dtype = dtypes.int32 + else: + dtype = dtypes.as_dtype(source_idx.dtype) + + # Handle negative indexing. Since ndims might be dynamic, this makes + # source_idx and dest_idx also possibly dynamic. + if source_idx < 0: + source_idx = ndims + source_idx + if dest_idx < 0: + dest_idx = ndims + dest_idx + + # Construct the appropriate permutation of dimensions, depending + # whether the source is before or after the destination. + def move_left_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, dest_idx, dtype=dtype), + [source_idx], + math_ops.range(dest_idx, source_idx, dtype=dtype), + math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0)) + + def move_right_permutation(): + return util.prefer_static_value( + array_ops.concat([ + math_ops.range(0, source_idx, dtype=dtype), + math_ops.range(source_idx+1, dest_idx+1, dtype=dtype), + [source_idx], + math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0)) + + def x_permuted(): + return array_ops.transpose( + x, perm=smart_cond.smart_cond(source_idx < dest_idx, + move_right_permutation, + move_left_permutation)) + + # One final conditional to handle the special case where source + # and destination indices are equal. + return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx), + lambda: x, + x_permuted) diff --git a/tensorflow/contrib/distributions/python/ops/estimator.py b/tensorflow/contrib/distributions/python/ops/estimator.py index 98edd337fe02ffbf53c6ecd9ebda9424231ea2fe..bdec6527d5378d6e86aa8e6279cc6ee672083e56 100644 --- a/tensorflow/contrib/distributions/python/ops/estimator.py +++ b/tensorflow/contrib/distributions/python/ops/estimator.py @@ -23,6 +23,7 @@ from tensorflow.contrib.learn.python.learn.estimators.head import _RegressionHea from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.util import deprecation __all__ = [ @@ -30,6 +31,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def estimator_head_distribution_regression(make_distribution_fn, label_dimension=1, logits_dimension=None, @@ -77,6 +86,14 @@ def estimator_head_distribution_regression(make_distribution_fn, class _DistributionRegressionHead(_RegressionHead): """Creates a _RegressionHead instance from an arbitrary `Distribution`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, make_distribution_fn, label_dimension, diff --git a/tensorflow/contrib/distributions/python/ops/geometric.py b/tensorflow/contrib/distributions/python/ops/geometric.py index 446cff6ec242f25178fed0c6a424791fa9f176ad..d62f024aa2a081f0ec231015af1f26a8851518e9 100644 --- a/tensorflow/contrib/distributions/python/ops/geometric.py +++ b/tensorflow/contrib/distributions/python/ops/geometric.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Geometric(distribution.Distribution): @@ -55,6 +56,14 @@ class Geometric(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, logits=None, probs=None, @@ -85,7 +94,7 @@ class Geometric(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ed9ea6f4f3ffe18fb6bf1e0a7d57728d010e0f01..acdea4d61d3ada7e9f4f0aa7bc58c5643db2802b 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class _Gumbel(distribution.Distribution): @@ -97,6 +97,14 @@ class _Gumbel(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -125,7 +133,7 @@ class _Gumbel(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index 7e12767f6d8f6c61565ecf266d3b222de68c0e40..b02c4031069191592b8acc1a90313450f98af6d7 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import special_math -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -86,6 +86,14 @@ class HalfNormal(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, scale, validate_args=False, @@ -106,7 +114,7 @@ class HalfNormal(distribution.Distribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index fa89fff3b7b2f8266a44c446a0c9807790b3aed8..0672702b96c1eb81c176774554df3f5922a0319e 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -29,7 +29,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Independent(distribution_lib.Distribution): @@ -95,6 +95,14 @@ class Independent(distribution_lib.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, distribution, reinterpreted_batch_ndims=None, validate_args=False, name=None): @@ -117,7 +125,7 @@ class Independent(distribution_lib.Distribution): ValueError: if `reinterpreted_batch_ndims` exceeds `distribution.batch_ndims` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) name = name or "Independent" + distribution.name self._distribution = distribution with ops.name_scope(name) as name: @@ -259,6 +267,14 @@ class Independent(distribution_lib.Distribution): @kullback_leibler.RegisterKL(Independent, Independent) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_independent(a, b, name="kl_independent"): """Batched KL divergence `KL(a || b)` for Independent distributions. diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 85e8e10466038e5e55ef4b754f82c0c2c2543b6d..70d050d7a647b38928ddb1c788db0e6957ac0f03 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -95,6 +96,14 @@ class InverseGamma(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, @@ -125,7 +134,7 @@ class InverseGamma(distribution.Distribution): Raises: TypeError: if `concentration` and `rate` are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: with ops.control_dependencies([ check_ops.assert_positive(concentration), @@ -274,13 +283,21 @@ class InverseGamma(distribution.Distribution): class InverseGammaWithSoftplusConcentrationRate(InverseGamma): """`InverseGamma` with softplus of `concentration` and `rate`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration, rate, validate_args=False, allow_nan_stats=True, name="InverseGammaWithSoftplusConcentrationRate"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[concentration, rate]) as name: super(InverseGammaWithSoftplusConcentrationRate, self).__init__( concentration=nn.softplus(concentration, diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 66682b2ff5493f8565410138e770b45ffc6b5d77..e3712dd84e36609d6bba4a5a39866046c0c8d1d8 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import special_math_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import uniform -from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.util import deprecation __all__ = [ "Kumaraswamy", @@ -41,6 +41,14 @@ _kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in `[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _harmonic_number(x): """Compute the harmonic number from its analytic continuation. @@ -59,7 +67,6 @@ def _harmonic_number(x): return math_ops.digamma(x + one) - math_ops.digamma(one) -@tf_export("distributions.Kumaraswamy") class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. @@ -125,6 +132,14 @@ class Kumaraswamy(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, concentration1=None, concentration0=None, diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 0103283259b0526b5a108ea1836f95709eedc067..02e3bad51ee48188acf83cb09359861c9e6932c7 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -31,7 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Logistic(distribution.Distribution): @@ -92,6 +92,14 @@ class Logistic(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -120,7 +128,7 @@ class Logistic(distribution.Distribution): Raises: TypeError: if loc and scale are different dtypes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: with ops.control_dependencies([check_ops.assert_positive(scale)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index d54f30dc634ab5c8aa82066056266747b63eec21..3b7114ef067c0aaede23fff04c40d1dc6e830f1c 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -32,6 +32,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class Mixture(distribution.Distribution): @@ -66,6 +67,14 @@ class Mixture(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, cat, components, @@ -116,7 +125,7 @@ class Mixture(distribution.Distribution): matching static batch shapes, or all components do not have matching static event shapes. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if not isinstance(cat, categorical.Categorical): raise TypeError("cat must be a Categorical distribution, but saw: %s" % cat) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index c7c90cf875484a1753577227bf22de878d00a502..8ffee940d03c9a5204f2ac6f7acd9ea482adae1a 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class MixtureSameFamily(distribution.Distribution): @@ -95,6 +96,14 @@ class MixtureSameFamily(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mixture_distribution, components_distribution, @@ -130,7 +139,7 @@ class MixtureSameFamily(distribution.Distribution): ValueError: if `mixture_distribution` categories does not equal `components_distribution` rightmost batch shape. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: self._mixture_distribution = mixture_distribution self._components_distribution = components_distribution @@ -321,6 +330,14 @@ class MixtureSameFamily(distribution.Distribution): return x +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" z = x - y diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index cad398582b9c939e8e96cf498638869ccd3701bd..cd0c282ba6cebf784261a4e821f36ce4eed98fe0 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -22,6 +22,7 @@ from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops import nn +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -193,7 +202,7 @@ class MultivariateNormalDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): @@ -218,13 +227,21 @@ class MultivariateNormalDiag( class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag): """MultivariateNormalDiag with `diag_stddev = softplus(diag_stddev)`.""" + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale_diag, validate_args=False, allow_nan_stats=True, name="MultivariateNormalDiagWithSoftplusScale"): - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[scale_diag]) as name: super(MultivariateNormalDiagWithSoftplusScale, self).__init__( loc=loc, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 1c11594df3ad2612dd8746bb8785d86390b69937..d8401801f21afbe8fd042053c6a38a31a2539438 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -141,6 +142,14 @@ class MultivariateNormalDiagPlusLowRank( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -215,7 +224,7 @@ class MultivariateNormalDiagPlusLowRank( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 47d7d13cf357f1ac657641420602c92eefdad197..dbc4c1b3dc956641f3e38ffafe3a3410bd3e2097 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -24,7 +24,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -113,6 +113,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, covariance_matrix=None, @@ -156,7 +164,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): Raises: ValueError: if neither `loc` nor `covariance_matrix` are specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) # Convert the covariance_matrix up to a scale_tril and call MVNTriL. with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 79916fef8d7b752649dcc673a84ea45ccf460905..efe5a6d0d99ca8fa9e0274049423bb3c4eef2d6f 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -27,6 +27,7 @@ from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -133,6 +134,14 @@ class MultivariateNormalLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -170,7 +179,7 @@ class MultivariateNormalLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: @@ -266,6 +275,14 @@ class MultivariateNormalLinearOperator( @kullback_leibler.RegisterKL(MultivariateNormalLinearOperator, MultivariateNormalLinearOperator) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_brute_force(a, b, name=None): """Batched KL divergence `KL(a || b)` for multivariate Normals. diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index d6b0ed994ec0a62e9b7684e7478130052a1fd300..d9110947ecdbba1a63669573f46db17b02e512ab 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -22,6 +22,7 @@ from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop from tensorflow.python.framework import ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ @@ -134,6 +135,14 @@ class MultivariateNormalTriL( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_tril=None, @@ -179,7 +188,7 @@ class MultivariateNormalTriL( Raises: ValueError: if neither `loc` nor `scale_tril` are specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) def _convert_to_tensor(x, name): return None if x is None else ops.convert_to_tensor(x, name=name) if loc is None and scale_tril is None: diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index 1085c56dc86c8d45bdab2e7cecedf44663e5c408..6acfc5746a0cc20e916de81b71f90e08d8d91ad5 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class NegativeBinomial(distribution.Distribution): @@ -51,6 +52,14 @@ class NegativeBinomial(distribution.Distribution): * `n!` is the factorial of `n`. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, total_count, logits=None, @@ -90,7 +99,7 @@ class NegativeBinomial(distribution.Distribution): name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[total_count, logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( logits, probs, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py index a4b9f3b78d4fdcc328bac84623114b921b9ded49..0c762f17c9b770ecada57b6ce60a4825ba374dd9 100644 --- a/tensorflow/contrib/distributions/python/ops/onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/onehot_categorical.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class OneHotCategorical(distribution.Distribution): @@ -83,6 +84,14 @@ class OneHotCategorical(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, logits=None, @@ -115,7 +124,7 @@ class OneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( name=name, logits=logits, probs=probs, validate_args=validate_args, @@ -233,6 +242,14 @@ class OneHotCategorical(distribution.Distribution): @kullback_leibler.RegisterKL(OneHotCategorical, OneHotCategorical) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _kl_categorical_categorical(a, b, name=None): """Calculate the batched KL divergence KL(a || b) with a, b OneHotCategorical. diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index b34539402102b8f289d4eb289fcb82f4030f4e8c..3d055085cc7386e57a71aa310458b7666bb9a396 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = [ "Poisson", @@ -65,6 +66,14 @@ class Poisson(distribution.Distribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, rate=None, log_rate=None, @@ -93,7 +102,7 @@ class Poisson(distribution.Distribution): TypeError: if `rate` is not a float-type. TypeError: if `log_rate` is not a float-type. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[rate]) as name: if (rate is None) == (log_rate is None): raise ValueError("Must specify exactly one of `rate` and `log_rate`.") diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index fe72091d7d759e54c51eb666f2ceacc8371e55fd..7a7ad1be35b80ff0f000181ea0778ab282a8220f 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -33,6 +33,7 @@ from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -42,6 +43,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_gauss_hermite( loc, scale, quadrature_size, validate_args=False, name=None): # pylint: disable=unused-argument @@ -85,6 +94,14 @@ def quadrature_scheme_lognormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_lognormal_quantiles( loc, scale, quadrature_size, validate_args=False, name=None): @@ -214,6 +231,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): validate_args=True) """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -255,7 +280,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): TypeError: if `quadrature_grid` and `quadrature_probs` have different base `dtype`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale]) as name: if loc is not None: loc = ops.convert_to_tensor(loc, name="loc") @@ -417,6 +442,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] diff --git a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py index 584d2c385fced95ec496bb8dae9556e5c376b66d..ef3bdfa75fcaa8df17db1238ceadadf788601356 100644 --- a/tensorflow/contrib/distributions/python/ops/quantized_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/quantized_distribution.py @@ -27,10 +27,19 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import distribution as distributions from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation __all__ = ["QuantizedDistribution"] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def _logsum_expbig_minus_expsmall(big, small): """Stable evaluation of `Log[exp{big} - exp{small}]`. @@ -228,6 +237,14 @@ class QuantizedDistribution(distributions.Distribution): https://arxiv.org/abs/1711.10433 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, distribution, low=None, @@ -263,7 +280,7 @@ class QuantizedDistribution(distributions.Distribution): `Distribution` or continuous. NotImplementedError: If the base distribution does not implement `cdf`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) values = ( list(distribution.parameters.values()) + [low, high]) diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py index 0362996e684fb34b15cd98a2fc40df58087fbe95..7e1f64dc425e6a576bfbe1bb456901fddfac26e1 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_bernoulli.py @@ -19,15 +19,16 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import logistic +from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid # Bijectors must be directly imported because `remove_undocumented` prevents # individual file imports. -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class RelaxedBernoulli(transformed_distribution.TransformedDistribution): @@ -131,6 +132,14 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Gumbel-Softmax. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, temperature, logits=None, @@ -165,7 +174,7 @@ class RelaxedBernoulli(transformed_distribution.TransformedDistribution): Raises: ValueError: If both `probs` and `logits` are passed, or if neither. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 910c430ae7f026a3ac9ce50d1d5936d4454cba41..9b5bd7576f2a3c364e21da76dd3905a8c6e35829 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -31,6 +31,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class ExpRelaxedOneHotCategorical(distribution.Distribution): @@ -125,6 +126,14 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, @@ -162,7 +171,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[logits, probs, temperature]) as name: self._logits, self._probs = distribution_util.get_logits_and_probs( @@ -368,6 +377,14 @@ class RelaxedOneHotCategorical( A Continuous Relaxation of Discrete Random Variables. 2016. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__( self, temperature, diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py index 6a7f28713acefd2285b07a212e2e47a6db1ae5e1..4f348be2806aa3ade7c1ea2a7bc68ca26db6447f 100644 --- a/tensorflow/contrib/distributions/python/ops/shape.py +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util import deprecation class _DistributionShape(object): @@ -166,6 +167,14 @@ class _DistributionShape(object): "free," i.e., during graph construction. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, batch_ndims=None, event_ndims=None, diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index f04dc8da39140240edbe4efb75de30e321436d55..a9d0fb4ccfb1803873f7fe17089f3e7c7f10f4b7 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "SinhArcsinh", @@ -94,6 +95,14 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc, scale, @@ -132,7 +141,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[loc, scale, skewness, tailweight]) as name: diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py index 9c69435fac109914ff29b307dfad105f62849339..c25e8c51d7705b641699fb05623c7b0fb4950e1b 100644 --- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py +++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py @@ -140,6 +140,7 @@ __all__ = [ "assert_true_mean_equal_by_dkwm", "min_discrepancy_of_true_means_detectable_by_dkwm", "min_num_samples_for_dkwm_mean_test", + "assert_true_mean_in_interval_by_dkwm", "assert_true_mean_equal_by_dkwm_two_sample", "min_discrepancy_of_true_means_detectable_by_dkwm_two_sample", "min_num_samples_for_dkwm_mean_two_sample_test", @@ -209,17 +210,17 @@ def _maximum_mean(samples, envelope, high, name=None): separately. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `high`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - high: Floating-point tensor of upper bounds on the distributions' - supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. `samples <= high`. name: A name for this operation (optional). Returns: - bound: Floating-point tensor of upper bounds on the true means. + bound: Floating-point `Tensor` of upper bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be larger than @@ -254,17 +255,17 @@ def _minimum_mean(samples, envelope, low, name=None): separately. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `envelope` and `low`. - envelope: Floating-point tensor of sizes of admissible CDF + envelope: Floating-point `Tensor` of sizes of admissible CDF envelopes (i.e., the `eps` above). - low: Floating-point tensor of lower bounds on the distributions' - supports. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. `samples >= low`. name: A name for this operation (optional). Returns: - bound: Floating-point tensor of lower bounds on the true means. + bound: Floating-point `Tensor` of lower bounds on the true means. Raises: InvalidArgumentError: If some `sample` is found to be smaller than @@ -300,12 +301,12 @@ def _dkwm_cdf_envelope(n, error_rate, name=None): probability above. Args: - n: Tensor of numbers of samples drawn. - error_rate: Floating-point tensor of admissible rates of mistakes. + n: `Tensor` of numbers of samples drawn. + error_rate: Floating-point `Tensor` of admissible rates of mistakes. name: A name for this operation (optional). Returns: - eps: Tensor of maximum distances the true CDF can be from the + eps: `Tensor` of maximum distances the true CDF can be from the empirical CDF. This scales as `O(sqrt(-log(error_rate)))` and as `O(1 / sqrt(n))`. The shape is the broadcast of `n` and `error_rate`. @@ -324,8 +325,8 @@ def _check_shape_dominates(samples, parameters): sample counts end up inflated. Args: - samples: A Tensor whose shape is to be protected against broadcasting. - parameters: A list of Tensors who are parameters for the statistical test. + samples: A `Tensor` whose shape is to be protected against broadcasting. + parameters: A list of `Tensor`s who are parameters for the statistical test. Returns: samples: Return original `samples` with control dependencies attached @@ -369,19 +370,23 @@ def true_mean_confidence_interval_by_dkwm( members. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - error_rate: *Scalar* admissible total rate of mistakes. + error_rate: *Scalar* floating-point `Tensor` admissible total rate + of mistakes. name: A name for this operation (optional). Returns: - low: A floating-point tensor of stochastic lower bounds on the true means. - high: A floating-point tensor of stochastic upper bounds on the true means. + low: A floating-point `Tensor` of stochastic lower bounds on the + true means. + high: A floating-point `Tensor` of stochastic upper bounds on the + true means. """ with ops.name_scope( name, "true_mean_confidence_interval_by_dkwm", @@ -436,15 +441,17 @@ def assert_true_mean_equal_by_dkwm( the assertion will insist on stronger evidence to fail any one member. Args: - samples: Floating-point tensor of samples from the distribution(s) + samples: Floating-point `Tensor` of samples from the distribution(s) of interest. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low` and `high`. - low: Floating-point tensor of lower bounds on the distributions' + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - expected: Floating-point tensor of expected true means. - false_fail_rate: *Scalar* admissible total rate of mistakes. + expected: Floating-point `Tensor` of expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -454,20 +461,8 @@ def assert_true_mean_equal_by_dkwm( with ops.name_scope( name, "assert_true_mean_equal_by_dkwm", [samples, low, high, expected, false_fail_rate]): - samples = ops.convert_to_tensor(samples, name="samples") - low = ops.convert_to_tensor(low, name="low") - high = ops.convert_to_tensor(high, name="high") - expected = ops.convert_to_tensor(expected, name="expected") - false_fail_rate = ops.convert_to_tensor( - false_fail_rate, name="false_fail_rate") - samples = _check_shape_dominates(samples, [low, high, expected]) - min_mean, max_mean = true_mean_confidence_interval_by_dkwm( - samples, low, high, error_rate=false_fail_rate) - less_op = check_ops.assert_less( - min_mean, expected, message="Mean confidence interval too high") - with ops.control_dependencies([less_op]): - return check_ops.assert_greater( - max_mean, expected, message="Mean confidence interval too low") + return assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected, expected, false_fail_rate) def min_discrepancy_of_true_means_detectable_by_dkwm( @@ -487,30 +482,35 @@ def min_discrepancy_of_true_means_detectable_by_dkwm( with the same `false_pass_rate`. Args: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. - low: Floating-point tensor of lower bounds on the distributions' + low: Floating-point `Tensor` of lower bounds on the distributions' supports. - high: Floating-point tensor of upper bounds on the distributions' + high: Floating-point `Tensor` of upper bounds on the distributions' supports. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true + discr: `Tensor` of lower bounds on the distances between true means detectable by a DKWM-based test. For each batch member `i`, of `K` total, drawing `n[i]` samples from some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discr[i]` or more. Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discr[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discr[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. The detectable discrepancy scales as @@ -558,17 +558,19 @@ def min_num_samples_for_dkwm_mean_test( on a scalar distribution supported on `[low, high]`. Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low: Tensor of lower bounds on the distributions' support. - high: Tensor of upper bounds on the distributions' support. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + low: `Tensor` of lower bounds on the distributions' support. + high: `Tensor` of upper bounds on the distributions' support. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n: Tensor of numbers of samples to be drawn from the distributions + n: `Tensor` of numbers of samples to be drawn from the distributions of interest. The `discrepancy`, `low`, and `high` tensors must have @@ -578,12 +580,15 @@ def min_num_samples_for_dkwm_mean_test( some scalar distribution supported on `[low[i], high[i]]` is enough to detect a difference in means of size `discrepancy[i]` or more. Specifically, we guarantee that (a) if the true mean is the expected - mean, `assert_true_mean_equal_by_dkwm` will fail with probability at - most `false_fail_rate / K` (which amounts to `false_fail_rate` if - applied to the whole batch at once), and (b) if the true mean - differs from the expected mean by at least `discrepancy[i]`, - `assert_true_mean_equal_by_dkwm` will pass with probability at most - `false_pass_rate`. + mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with + probability at most `false_fail_rate / K` (which amounts to + `false_fail_rate` if applied to the whole batch at once), and (b) if + the true mean differs from the expected mean (resp. falls outside + the expected interval) by at least `discrepancy[i]`, + `assert_true_mean_equal_by_dkwm` + (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with + probability at most `false_pass_rate`. The required number of samples scales as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`, @@ -610,6 +615,76 @@ def min_num_samples_for_dkwm_mean_test( return math_ops.maximum(n1, n2) +def assert_true_mean_in_interval_by_dkwm( + samples, low, high, expected_low, expected_high, + false_fail_rate=1e-6, name=None): + """Asserts the mean of the given distribution is in the given interval. + + More precisely, fails if there is enough evidence (using the + [Dvoretzky-Kiefer-Wolfowitz-Massart inequality] + (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval)) + that the mean of the distribution from which the given samples are + drawn is _outside_ the given interval with statistical significance + `false_fail_rate` or stronger, otherwise passes. If you also want + to check that you are gathering enough evidence that a pass is not + spurious, see `min_num_samples_for_dkwm_mean_test` and + `min_discrepancy_of_true_means_detectable_by_dkwm`. + + Note that `false_fail_rate` is a total false failure rate for all + the assertions in the batch. As such, if the batch is nontrivial, + the assertion will insist on stronger evidence to fail any one member. + + Args: + samples: Floating-point `Tensor` of samples from the distribution(s) + of interest. Entries are assumed IID across the 0th dimension. + The other dimensions must broadcast with `low` and `high`. + The support is bounded: `low <= samples <= high`. + low: Floating-point `Tensor` of lower bounds on the distributions' + supports. + high: Floating-point `Tensor` of upper bounds on the distributions' + supports. + expected_low: Floating-point `Tensor` of lower bounds on the + expected true means. + expected_high: Floating-point `Tensor` of upper bounds on the + expected true means. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. + name: A name for this operation (optional). + + Returns: + check: Op that raises `InvalidArgumentError` if any expected mean + interval does not overlap with the corresponding confidence + interval. + """ + with ops.name_scope( + name, "assert_true_mean_in_interval_by_dkwm", + [samples, low, high, expected_low, expected_high, false_fail_rate]): + samples = ops.convert_to_tensor(samples, name="samples") + low = ops.convert_to_tensor(low, name="low") + high = ops.convert_to_tensor(high, name="high") + expected_low = ops.convert_to_tensor(expected_low, name="expected_low") + expected_high = ops.convert_to_tensor(expected_high, name="expected_high") + false_fail_rate = ops.convert_to_tensor( + false_fail_rate, name="false_fail_rate") + samples = _check_shape_dominates( + samples, [low, high, expected_low, expected_high]) + min_mean, max_mean = true_mean_confidence_interval_by_dkwm( + samples, low, high, false_fail_rate) + # Assert that the interval [min_mean, max_mean] intersects the + # interval [expected_low, expected_high]. This is true if + # max_mean >= expected_low and min_mean <= expected_high. + # By DeMorgan's law, that's also equivalent to + # not (max_mean < expected_low or min_mean > expected_high), + # which is a way of saying the two intervals are not disjoint. + check_confidence_interval_can_intersect = check_ops.assert_greater_equal( + max_mean, expected_low, message="Confidence interval does not " + "intersect: true mean smaller than expected") + with ops.control_dependencies([check_confidence_interval_can_intersect]): + return check_ops.assert_less_equal( + min_mean, expected_high, message="Confidence interval does not " + "intersect: true mean greater than expected") + + def assert_true_mean_equal_by_dkwm_two_sample( samples1, low1, high1, samples2, low2, high2, false_fail_rate=1e-6, name=None): @@ -630,23 +705,26 @@ def assert_true_mean_equal_by_dkwm_two_sample( the assertion will insist on stronger evidence to fail any one member. Args: - samples1: Floating-point tensor of samples from the + samples1: Floating-point `Tensor` of samples from the distribution(s) A. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low1: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low1 <= samples1 <= high1`. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - samples2: Floating-point tensor of samples from the + samples2: Floating-point `Tensor` of samples from the distribution(s) B. Entries are assumed IID across the 0th dimension. The other dimensions must broadcast with `low1`, `high1`, `low2`, and `high2`. - low2: Floating-point tensor of lower bounds on the supports of the + The support is bounded: `low2 <= samples2 <= high2`. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of mistakes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of mistakes. name: A name for this operation (optional). Returns: @@ -676,20 +754,10 @@ def assert_true_mean_equal_by_dkwm_two_sample( # and sample counts should be valid; however, because the intervals # scale as O(-log(false_fail_rate)), there doesn't seem to be much # room to win. - min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm( - samples1, low1, high1, false_fail_rate / 2.) min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm( samples2, low2, high2, false_fail_rate / 2.) - # I want to assert - # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2), - # but I think I only have and-combination of asserts, so use DeMorgan. - check_confidence_intervals_can_intersect = check_ops.assert_greater_equal( - max_mean_1, min_mean_2, message="Confidence intervals do not " - "intersect: samples1 has a smaller mean than samples2") - with ops.control_dependencies([check_confidence_intervals_can_intersect]): - return check_ops.assert_less_equal( - min_mean_1, max_mean_2, message="Confidence intervals do not " - "intersect: samples2 has a smaller mean than samples1") + return assert_true_mean_in_interval_by_dkwm( + samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.) def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( @@ -710,22 +778,24 @@ def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample( with the same `false_pass_rate`. Args: - n1: Tensor of numbers of samples to be drawn from the distributions A. - low1: Floating-point tensor of lower bounds on the supports of the + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. - low2: Floating-point tensor of lower bounds on the supports of the + n2: `Tensor` of numbers of samples to be drawn from the distributions B. + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - discr: Tensor of lower bounds on the distances between true means + discr: `Tensor` of lower bounds on the distances between true means detectable by a two-sample DKWM-based test. For each batch member `i`, of `K` total, drawing `n1[i]` samples @@ -776,24 +846,26 @@ def min_num_samples_for_dkwm_mean_two_sample_test( (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval). Args: - discrepancy: Floating-point tensor of desired upper limits on mean + discrepancy: Floating-point `Tensor` of desired upper limits on mean differences that may go undetected with probability higher than `1 - false_pass_rate`. - low1: Floating-point tensor of lower bounds on the supports of the + low1: Floating-point `Tensor` of lower bounds on the supports of the distributions A. - high1: Floating-point tensor of upper bounds on the supports of + high1: Floating-point `Tensor` of upper bounds on the supports of the distributions A. - low2: Floating-point tensor of lower bounds on the supports of the + low2: Floating-point `Tensor` of lower bounds on the supports of the distributions B. - high2: Floating-point tensor of upper bounds on the supports of + high2: Floating-point `Tensor` of upper bounds on the supports of the distributions B. - false_fail_rate: *Scalar* admissible total rate of false failures. - false_pass_rate: *Scalar* admissible rate of false passes. + false_fail_rate: *Scalar* floating-point `Tensor` admissible total + rate of false failures. + false_pass_rate: *Scalar* floating-point `Tensor` admissible rate + of false passes. name: A name for this operation (optional). Returns: - n1: Tensor of numbers of samples to be drawn from the distributions A. - n2: Tensor of numbers of samples to be drawn from the distributions B. + n1: `Tensor` of numbers of samples to be drawn from the distributions A. + n2: `Tensor` of numbers of samples to be drawn from the distributions B. For each batch member `i`, of `K` total, drawing `n1[i]` samples from scalar distribution A supported on `[low1[i], high1[i]]` and `n2[i]` diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index cd6d7499595d88d18de339371d4a07fe780662d9..ece03fe4aab3cc3046e0958d883ca9388517b94b 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -40,6 +40,7 @@ from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib +from tensorflow.python.util import deprecation __all__ = [ @@ -49,6 +50,14 @@ __all__ = [ ] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_gauss_hermite( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -111,6 +120,14 @@ def quadrature_scheme_softmaxnormal_gauss_hermite( return grid, probs +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def quadrature_scheme_softmaxnormal_quantiles( normal_loc, normal_scale, quadrature_size, validate_args=False, name=None): @@ -318,6 +335,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): https://arxiv.org/abs/1801.03080 """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, mix_loc, temperature, @@ -395,7 +420,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ValueError: if `not distribution.is_scalar_batch`. ValueError: if `not distribution.is_scalar_event`. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[mix_loc, temperature]) as name: if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " @@ -779,6 +804,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return array_ops.reshape(p, shape=expand_shape) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def maybe_check_quadrature_param(param, name, validate_args): """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): @@ -812,6 +845,14 @@ def maybe_check_quadrature_param(param, name, validate_args): return param +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): @@ -850,6 +891,14 @@ def determine_batch_event_shapes(grid, endpoint_affine): return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: @@ -876,6 +925,14 @@ def interpolate_loc(grid, loc): return [x[..., k] for k in range(deg)] # list(shape:[B, e]) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: @@ -892,6 +949,14 @@ def interpolate_scale(grid, scale): ])[0] for q in range(deg)] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def linop_scale(w, op): # We assume w > 0. (This assumption only relates to the is_* attributes.) with ops.name_scope("linop_scale", values=[w]): @@ -927,6 +992,14 @@ def linop_scale(w, op): "Unsupported Linop type ({})".format(type(op).__name__)) +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [distribution_util.static_value(x) for x in args] @@ -935,6 +1008,14 @@ def concat_vectors(*args): return [val for vec in args_ for val in vec] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -944,11 +1025,27 @@ def add(x, y): return x + y +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] +@deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def softmax(x, axis, name=None): """Equivalent to tf.nn.softmax but works around b/70297725.""" with ops.name_scope(name, "softmax", [x, axis]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 3465d66b30501e7aebd9904d2ae2206d628c10b7..73356a3625c9a1aa15af5b6c1cf2ccb0c514b39a 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_exponential_linear_operator as vector_exponential_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -116,6 +117,14 @@ class VectorExponentialDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -175,7 +184,7 @@ class VectorExponentialDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 2c31b019845d7e4558eb3047af84732a2ae03986..9a47b4855763a25b484ad04a3415d191f19256f7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -26,6 +26,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = ["VectorExponentialLinearOperator"] @@ -138,6 +139,14 @@ class VectorExponentialLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -175,7 +184,7 @@ class VectorExponentialLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 6a36018d6f1b83955ef9080ec11c74c08a670075..e68ddc569c95ff63760b4b2f6d7a92f17240a558 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import vector_laplace_linear_operator as vector_laplace_linop from tensorflow.python.framework import ops +from tensorflow.python.util import deprecation __all__ = [ @@ -151,6 +152,14 @@ class VectorLaplaceDiag( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -210,7 +219,7 @@ class VectorLaplaceDiag( Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name): with ops.name_scope("init", values=[ loc, scale_diag, scale_identity_multiplier]): diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index 97e5c76d800acd800e34a9e66a3c5fdd7ce4f660..3923161a332a77e4eaab8d65d96fd8c278c872ec 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.linalg import linalg +from tensorflow.python.util import deprecation __all__ = [ @@ -154,6 +155,14 @@ class VectorLaplaceLinearOperator( """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale=None, @@ -191,7 +200,7 @@ class VectorLaplaceLinearOperator( ValueError: if `scale` is unspecified. TypeError: if not `scale.dtype.is_floating` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) if scale is None: raise ValueError("Missing required `scale` parameter.") if not scale.dtype.is_floating: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index ff5ca4525700aedc88d75e391bf0c2415c2afa13..49ffff24caec8d6c525f65f06796d10548d5ec40 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation __all__ = [ "VectorSinhArcsinhDiag", @@ -95,6 +96,14 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): ``` """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, loc=None, scale_diag=None, @@ -163,7 +172,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): Raises: ValueError: if at most `scale_identity_multiplier` is specified. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope( name, diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 4742f7521816d4643354017495f3380c78ac7bc2..f289b39e51aff36780541a0545ed9e6cfe21dd4e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import student_t from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.util import deprecation class _VectorStudentT(transformed_distribution.TransformedDistribution): @@ -121,6 +122,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, loc=None, @@ -175,7 +184,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) graph_parents = [df, loc, scale_identity_multiplier, scale_diag, scale_tril, scale_perturb_factor, scale_perturb_diag] with ops.name_scope(name) as name: diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index f555867e7f3c2a6bc797e9b3d56da2fa434aba6f..f1accaaa4c920344608015c792a2c3606de1337f 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution +from tensorflow.python.util import deprecation __all__ = [ "WishartCholesky", @@ -73,6 +74,14 @@ class _WishartLinearOperator(distribution.Distribution): this class. """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale_operator, @@ -107,7 +116,7 @@ class _WishartLinearOperator(distribution.Distribution): ValueError: if df < k, where scale operator event shape is `(k, k)` """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) self._cholesky_input_output_matrices = cholesky_input_output_matrices with ops.name_scope(name) as name: with ops.name_scope("init", values=[df, scale_operator]): @@ -501,6 +510,14 @@ class WishartCholesky(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -530,7 +547,7 @@ class WishartCholesky(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name, values=[scale]) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) @@ -617,6 +634,14 @@ class WishartFull(_WishartLinearOperator): """ + @deprecation.deprecated( + "2018-10-01", + "The TensorFlow Distributions library has moved to " + "TensorFlow Probability " + "(https://github.com/tensorflow/probability). You " + "should update all references to use `tfp.distributions` " + "instead of `tf.contrib.distributions`.", + warn_once=True) def __init__(self, df, scale, @@ -646,7 +671,7 @@ class WishartFull(_WishartLinearOperator): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ - parameters = distribution_util.parent_frame_arguments() + parameters = dict(locals()) with ops.name_scope(name) as name: with ops.name_scope("init", values=[scale]): scale = ops.convert_to_tensor(scale) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index d7909dd5a2691a015a6afed2caa475b39ca7ebc3..adf92c27ea0a27c5741bcdd175b277462cb28d02 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -106,7 +106,8 @@ class Iterator(iterator_ops.EagerIterator, checkpointable.CheckpointableBase): target_device=target, buffer_size=10, container="", - shared_name=_generate_shared_name("function_buffer_resource")) + shared_name=_generate_shared_name( + "contrib_eager_iterator_function_buffer_resource")) self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long handle=self._buffer_resource_handle, handle_device=self._device) diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index c1fd9e0ed020beeb722204edf1adfe1dfcf8ff03..1d9371c7ac405dbf0ec40210270b90f2cf9b9a25 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -7,6 +7,8 @@ py_library( name = "examples_pip", deps = [ "//tensorflow/contrib/eager/python/examples/gan:mnist", + "//tensorflow/contrib/eager/python/examples/l2hmc", + "//tensorflow/contrib/eager/python/examples/l2hmc:neural_nets", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/BUILD b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7bdf9053de749af9d09b12ba7b848e21c1fdb8f0 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/BUILD @@ -0,0 +1,39 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "neural_nets", + srcs = ["neural_nets.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +py_library( + name = "l2hmc", + srcs = ["l2hmc.py"], + srcs_version = "PY2AND3", + deps = [ + ":neural_nets", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", + ], +) + +cuda_py_test( + name = "l2hmc_test", + size = "large", + srcs = ["l2hmc_test.py"], + additional_deps = [ + ":l2hmc", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py new file mode 100644 index 0000000000000000000000000000000000000000..729d8525fab31ee214178ca1bcb18dbd069f767a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py @@ -0,0 +1,326 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets + + +class Dynamics(tf.keras.Model): + """Dynamics engine of naive L2HMC sampler. + + Args: + x_dim: dimensionality of observed data + loglikelihood_fn: log-likelihood function of conditional probability + n_steps: number of leapfrog steps within each transition + eps: initial value learnable scale of step size + """ + + def __init__(self, x_dim, loglikelihood_fn, n_steps=25, eps=.1): + super(Dynamics, self).__init__() + + self.x_dim = x_dim + self.potential = loglikelihood_fn + self.n_steps = n_steps + + self._construct_time() + self._construct_masks() + + self.position_fn = neural_nets.GenericNet(x_dim, factor=2.) + self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.) + + self.eps = tfe.Variable( + initial_value=eps, name="eps", dtype=tf.float32, trainable=True) + + def apply_transition(self, position): + """Propose a new state and perform the accept or reject step.""" + + # Simulate dynamics both forward and backward; + # Use sampled Bernoulli masks to compute the actual solutions + position_f, momentum_f, accept_prob_f = self.transition_kernel( + position, forward=True) + position_b, momentum_b, accept_prob_b = self.transition_kernel( + position, forward=False) + + # Decide direction uniformly + forward_mask = tf.cast( + tf.random_uniform(shape=[tf.shape(position)[0]]) > .5, tf.float32) + backward_mask = 1. - forward_mask + + # Obtain proposed states + position_post = ( + forward_mask[:, None] * position_f + + backward_mask[:, None] * position_b) + momentum_post = ( + forward_mask[:, None] * momentum_f + + backward_mask[:, None] * momentum_b) + + # Probability of accepting the proposed states + accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b + + # Accept or reject step + accept_mask = tf.cast( + accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32) + reject_mask = 1. - accept_mask + + # Samples after accept/reject step + position_out = ( + accept_mask[:, None] * position_post + reject_mask[:, None] * position) + + return position_post, momentum_post, accept_prob, position_out + + def transition_kernel(self, position, forward=True): + """Transition kernel of augmented leapfrog integrator.""" + + lf_fn = self._forward_lf if forward else self._backward_lf + + # Resample momentum + momentum = tf.random_normal(tf.shape(position)) + position_post, momentum_post = position, momentum + sumlogdet = 0. + # Apply augmented leapfrog steps + for i in range(self.n_steps): + position_post, momentum_post, logdet = lf_fn(position_post, momentum_post, + i) + sumlogdet += logdet + + accept_prob = self._compute_accept_prob(position, momentum, position_post, + momentum_post, sumlogdet) + + return position_post, momentum_post, accept_prob + + def _forward_lf(self, position, momentum, i): + """One forward augmented leapfrog step. See eq (5-6) in paper.""" + + t = self._get_time(i) + mask, mask_inv = self._get_mask(i) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_forward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_forward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _backward_lf(self, position, momentum, i): + """One backward augmented leapfrog step. See Appendix A in paper.""" + + # Reversed index/sinusoidal time + t = self._get_time(self.n_steps - i - 1) + mask, mask_inv = self._get_mask(self.n_steps - i - 1) + sumlogdet = 0. + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask) + sumlogdet += logdet + + position, logdet = self._update_position_backward(position, momentum, t, + mask_inv) + sumlogdet += logdet + + momentum, logdet = self._update_momentum_backward(position, momentum, t) + sumlogdet += logdet + + return position, momentum, tf.reduce_sum(sumlogdet, axis=1) + + def _update_momentum_forward(self, position, momentum, t): + """Update v in the forward leapfrog step.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= .5 * self.eps + transformed *= self.eps + momentum = ( + momentum * tf.exp(scale) - + .5 * self.eps * (tf.exp(transformed) * grad - translation)) + + return momentum, scale + + def _update_position_forward(self, position, momentum, t, mask): + """Update x in the forward leapfrog step.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask * position, t]) + scale *= self.eps + transformed *= self.eps + position = ( + mask * position + + mask_inv * (position * tf.exp(scale) + self.eps * + (tf.exp(transformed) * momentum + translation))) + + return position, mask_inv * scale + + def _update_momentum_backward(self, position, momentum, t): + """Update v in the backward leapfrog step. Inverting the forward update.""" + + grad = self.grad_potential(position) + scale, translation, transformed = self.momentum_fn([position, grad, t]) + scale *= -.5 * self.eps + transformed *= self.eps + momentum = ( + tf.exp(scale) * (momentum + .5 * self.eps * + (tf.exp(transformed) * grad - translation))) + + return momentum, scale + + def _update_position_backward(self, position, momentum, t, mask): + """Update x in the backward leapfrog step. Inverting the forward update.""" + + mask_inv = 1. - mask + scale, translation, transformed = self.position_fn( + [momentum, mask_inv * position, t]) + scale *= -self.eps + transformed *= self.eps + position = ( + mask_inv * position + mask * tf.exp(scale) * + (position - self.eps * tf.exp(transformed) * momentum + translation)) + + return position, mask * scale + + def _compute_accept_prob(self, position, momentum, position_post, + momentum_post, sumlogdet): + """Compute the prob of accepting the proposed state given old state.""" + + old_hamil = self.hamiltonian(position, momentum) + new_hamil = self.hamiltonian(position_post, momentum_post) + + return tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.)) + + def _construct_time(self): + """Convert leapfrog step index into sinusoidal time.""" + + self.ts = [] + for i in range(self.n_steps): + t = tf.constant( + [ + np.cos(2 * np.pi * i / self.n_steps), + np.sin(2 * np.pi * i / self.n_steps) + ], + dtype=tf.float32) + self.ts.append(t[None, :]) + + def _get_time(self, i): + """Get sinusoidal time for i-th augmented leapfrog step.""" + + return self.ts[i] + + def _construct_masks(self): + """Construct different binary masks for different time steps.""" + + self.masks = [] + for _ in range(self.n_steps): + idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2] + mask = np.zeros((self.x_dim,)) + mask[idx] = 1. + mask = tf.constant(mask, dtype=tf.float32) + self.masks.append(mask[None, :]) + + def _get_mask(self, i): + """Get binary masks for i-th augmented leapfrog step.""" + + m = self.masks[i] + return m, 1. - m + + def kinetic(self, v): + """Compute the kinetic energy.""" + + return .5 * tf.reduce_sum(v**2, axis=1) + + def hamiltonian(self, position, momentum): + """Compute the overall Hamiltonian.""" + + return self.potential(position) + self.kinetic(momentum) + + def grad_potential(self, position, check_numerics=True): + """Get gradient of potential function at current location.""" + + if not tf.executing_eagerly(): + # TODO(lxuechen): Change this to tfe.gradients_function when it works + grad = tf.gradients(self.potential(position), position)[0] + else: + grad = tfe.gradients_function(self.potential)(position)[0] + + if check_numerics: + return tf.check_numerics(grad, message="gradient of potential") + + return grad + + +# Examples of unnormalized log density/probabilities +def get_scg_energy_fn(): + """Get energy function for 2d strongly correlated Gaussian.""" + + # Avoid recreating tf constants on each invocation of gradients + mu = tf.constant([0., 0.]) + sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]]) + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy + + +def get_multivariate_gaussian_energy_fn(x_dim=2): + """Get energy function for 2d strongly correlated Gaussian.""" + + mu = tf.random_normal(shape=[x_dim]) + # Lower triangularize and positive diagonal + l = tf.sigmoid( + tf.matrix_band_part(tf.random_normal(shape=[x_dim, x_dim]), -1, 0)) + # Exploit Cholesky decomposition + sigma = tf.matmul(l, tf.transpose(l)) + sigma *= 100. # Small covariance causes extreme numerical instability + sigma_inv = tf.matrix_inverse(sigma) + + def energy(x): + """Unnormalized log density/energy of 2d strongly correlated Gaussian.""" + + xmmu = x - mu + return .5 * tf.diag_part( + tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu))) + + return energy diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e33b4cae4c73388dfd78542c9907953f137ad710 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc_test.py @@ -0,0 +1,264 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests l2hmc fit to 2D strongly correlated Gaussian executed eagerly.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import numpy.random as npr +import tensorflow as tf +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc + + +def get_default_hparams(): + return tf.contrib.training.HParams( + x_dim=2, + n_samples=200, + n_steps=10, + eps=.1, + n_iters=10, + learning_rate=.0003, + n_warmup_iters=3) + + +# Relevant functions for benchmarking +def compute_loss(dynamics, x, scale=.1, eps=1e-4): + """Compute loss defined in equation (8).""" + + z = tf.random_normal(tf.shape(x)) + x_, _, x_accept_prob, x_out = dynamics.apply_transition(x) + z_, _, z_accept_prob, _ = dynamics.apply_transition(z) + + # Add eps for numerical stability; following released impl + x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps + z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps + + loss = tf.reduce_mean( + (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0) + + return loss, x_out + + +def loss_and_grads(dynamics, x, loss_fn=compute_loss): + """Obtain loss value and gradients.""" + + with tf.GradientTape() as tape: + loss_val, x_out = loss_fn(dynamics, x) + grads = tape.gradient(loss_val, dynamics.variables) + + return loss_val, grads, x_out + + +def warmup(dynamics, optimizer, n_iters=1, n_samples=200, loss_fn=compute_loss): + """Warmup optimization to reduce overhead.""" + + samples = tf.random_normal( + shape=[n_samples, dynamics.x_dim], dtype=tf.float32) + + for _ in range(n_iters): + _, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + optimizer.apply_gradients(zip(grads, dynamics.variables)) + + +def fit(dynamics, + samples, + optimizer, + loss_fn=compute_loss, + n_iters=5000, + verbose=True, + logdir=None, + decay_lr=True): + """Fit L2HMC sampler with given log-likelihood function.""" + + if logdir: + summary_writer = tf.contrib.summary.create_file_writer(logdir) + + for i in range(n_iters): + loss, grads, samples = loss_and_grads(dynamics, samples, loss_fn=loss_fn) + # TODO(lxuechen): Proper learning rate decay + if decay_lr: + grads = [grad * .96**(i // 1000) for grad in grads] + optimizer.apply_gradients(zip(grads, dynamics.variables)) + if verbose: + print("Iteration %d: loss %.4f" % (i, loss)) + + if logdir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", loss) + + +class L2hmcTest(tf.test.TestCase): + """Unit tests for l2hmc in both eager and graph mode.""" + + def test_apply_transition(self): + """Testing function `Dynamics.apply_transition` in graph and eager mode.""" + + # Eager mode testing + hparams = get_default_hparams() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + samples = tf.random_normal(shape=[hparams.n_samples, hparams.x_dim]) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(samples) + + self.assertEqual(x_.shape, v_.shape) + self.assertEqual(x_out.shape, samples.shape) + self.assertEqual(x_.shape, x_out.shape) + self.assertEqual(x_accept_prob.shape, (hparams.n_samples,)) + + # Graph mode testing + with tf.Graph().as_default(): + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=l2hmc.get_scg_energy_fn(), + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + x_, v_, x_accept_prob, x_out = dynamics.apply_transition(x) + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + np_x_, np_v_, np_x_accept_prob, np_x_out = sess.run( + [x_, v_, x_accept_prob, x_out], feed_dict={x: samples}) + + self.assertEqual(np_x_.shape, np_v_.shape) + self.assertEqual(samples.shape, np_x_out.shape) + self.assertEqual(np_x_.shape, np_x_out.shape) + self.assertEqual(np_x_accept_prob.shape, (hparams.n_samples,)) + + +class L2hmcBenchmark(tf.test.Benchmark): + """Eager and graph benchmarks for l2hmc.""" + + def _get_energy_fn(self): + """Get specific energy function according to FLAGS.""" + + if FLAGS.energy_fn == "scg": + energy_fn = l2hmc.get_scg_energy_fn() + elif FLAGS.energy_fn == "multivariate_gaussian": + energy_fn = l2hmc.get_multivariate_gaussian_energy_fn(x_dim=FLAGS.x_dim) + else: + raise ValueError("No such energy function %s" % FLAGS.energy_fn) + + return energy_fn + + def benchmark_graph(self): + """Benchmark Graph performance.""" + + hparams = get_default_hparams() + tf.reset_default_graph() + with tf.Graph().as_default(): + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + x = tf.placeholder(tf.float32, shape=[None, hparams.x_dim]) + loss, x_out = compute_loss(dynamics, x) + + global_step = tf.Variable(0., name="global_step", trainable=False) + learning_rate = tf.train.exponential_decay( + hparams.learning_rate, global_step, 1000, 0.96, staircase=True) + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + train_op = optimizer.minimize(loss, global_step=global_step) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + # Warmup to reduce initialization effect when timing + samples = npr.normal(size=[hparams.n_samples, hparams.x_dim]) + for _ in range(hparams.n_warmup_iters): + _, _, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + + # Training + start_time = time.time() + for i in range(hparams.n_iters): + samples, loss_np, _, _ = sess.run( + [x_out, loss, train_op, learning_rate], feed_dict={x: samples}) + print("Iteration %d: loss %.4f" % (i, loss_np)) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="graph_train_%s" % ("gpu" + if tf.test.is_gpu_available() else "cpu"), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + def benchmark_eager(self): + self._benchmark_eager() + + def benchmark_eager_defun(self): + self._benchmark_eager(defun=True) + + def _benchmark_eager(self, defun=False): + """Benchmark Eager performance.""" + + hparams = get_default_hparams() + energy_fn = self._get_energy_fn() + dynamics = l2hmc.Dynamics( + x_dim=hparams.x_dim, + loglikelihood_fn=energy_fn, + n_steps=hparams.n_steps, + eps=hparams.eps) + optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate) + loss_fn = tfe.defun(compute_loss) if defun else compute_loss + + # Warmup to reduce initialization effect when timing + warmup(dynamics, optimizer, n_iters=hparams.n_warmup_iters, loss_fn=loss_fn) + + # Training + samples = tf.random_normal( + shape=[hparams.n_samples, hparams.x_dim], dtype=tf.float32) + start_time = time.time() + fit(dynamics, + samples, + optimizer, + loss_fn=loss_fn, + n_iters=hparams.n_iters, + decay_lr=True) + wall_time = time.time() - start_time + examples_per_sec = hparams.n_samples / wall_time + + self.report_benchmark( + name="eager_train_%s%s" % ("gpu" if tf.test.is_gpu_available() else + "cpu", "_defun" if defun else ""), + iters=hparams.n_iters, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + del dynamics + del loss_fn + + +if __name__ == "__main__": + tf.flags.DEFINE_string("energy_fn", "scg", + ("The energy function/unnormalized log-probability. " + "Either be `scg` or `multivariate_gaussian`")) + tf.flags.DEFINE_integer("x_dim", 2, "Dimensionality of observation space.") + FLAGS = tf.flags.FLAGS + tf.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py new file mode 100644 index 0000000000000000000000000000000000000000..e230ad5e259df5b450897bd815e901e3934cd293 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py @@ -0,0 +1,84 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Neural nets utility for L2HMC compatible with TensorFlow's eager execution. + +Reference [Generalizing Hamiltonian Monte Carlo with Neural +Networks](https://arxiv.org/pdf/1711.09268.pdf) + +Code adapted from the released TensorFlow graph implementation by original +authors https://github.com/brain-research/l2hmc. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import tensorflow.contrib.eager as tfe + + +class GenericNet(tf.keras.Model): + """Generic neural net with different initialization scale based on input. + + Args: + x_dim: dimensionality of observed data + factor: factor of variance scaling initializer + n_hidden: number of hidden units + """ + + def __init__(self, x_dim, factor, n_hidden=10): + super(GenericNet, self).__init__() + + self.v_layer = _custom_dense(n_hidden, 1. / 3.) + self.x_layer = _custom_dense(n_hidden, factor / 3.) + self.t_layer = _custom_dense(n_hidden, 1. / 3.) + self.h_layer = _custom_dense(n_hidden) + + # Scale + self.scale_layer = _custom_dense(x_dim, .001) + self.coeff_scale = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True) + # Translation + self.translation_layer = _custom_dense(x_dim, factor=.001) + # Transformation + self.transformation_layer = _custom_dense(x_dim, .001) + self.coeff_transformation = tfe.Variable( + initial_value=tf.zeros([1, x_dim]), + name='coeff_transformation', + trainable=True) + + def call(self, inputs): + v, x, t = inputs + h = self.v_layer(v) + self.x_layer(x) + self.t_layer(t) + h = tf.nn.relu(h) + h = self.h_layer(h) + h = tf.nn.relu(h) + scale = tf.nn.tanh(self.scale_layer(h)) * tf.exp(self.coeff_scale) + translation = self.translation_layer(h) + transformation = ( + tf.nn.tanh(self.transformation_layer(h)) * tf.exp( + self.coeff_transformation)) + + return scale, translation, transformation + + +def _custom_dense(units, factor=1.): + """Custom dense layer with specified weight initialization.""" + + return tf.keras.layers.Dense( + units=units, + use_bias=True, + kernel_initializer=tf.contrib.layers.variance_scaling_initializer( + factor=factor * 2., mode='FAN_IN', uniform=False), + bias_initializer=tf.constant_initializer(0., dtype=tf.float32)) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 2259c20741ab689dbe0d08d32ff05fc7f8a2100d..099b712fc06d1d3eb9ab4095f8db7283690bda76 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -75,7 +75,6 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): mse = lambda xs, ys: mean_square_loss(model, xs, ys) loss_and_grads = tfe.implicit_value_and_gradients(mse) - tf.train.get_or_create_global_step() if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= @@ -87,12 +86,13 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if verbose: print("Iteration %d: loss = %s" % (i, loss.numpy())) - optimizer.apply_gradients(grads, global_step=tf.train.get_global_step()) + optimizer.apply_gradients(grads) if logdir: with summary_writer.as_default(): with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar("loss", loss) + tf.contrib.summary.scalar("loss", loss, step=i) + tf.contrib.summary.scalar("step", i, step=i) def synthetic_dataset(w, b, noise_level, batch_size, num_batches): diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb index 9fd2d8d1254e32ae75ab5b085986c6e1c05e76f4..51d10a778413cfbb574b4e22e8adcb18bd731dee 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb @@ -1,495 +1,429 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "Eager Execution Tutorial: Basics", - "version": "0.3.2", - "views": {}, - "default_view": {}, - "provenance": [ - { - "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg", - "timestamp": 1504118841551 - } - ] - } - }, "cells": [ { + "cell_type": "markdown", "metadata": { - "id": "U9i2Dsh-ziXr", - "colab_type": "text" + "colab_type": "text", + "id": "U9i2Dsh-ziXr" }, - "cell_type": "markdown", "source": [ - "# Eager Execution Tutorial: Basics\n", + "# An introduction to TensorFlow\n", "\n", - "This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n", + "This is an introductory tutorial for using TensorFlow. It will cover:\n", "\n", "* Importing required packages\n", - "* Enabling eager execution\n", - "* Creating and using TensorFlow Tensors and Variables\n", - "* Using TensorFlow interactively\n", - "* Using GPUs with eager execution enabled\n", - "\n", - "This notebook does *not* cover modeling topics, such as gradients." + "* Creating and using Tensors\n", + "* Using GPU acceleration\n" ] }, { + "cell_type": "markdown", "metadata": { - "id": "z1JcS5iBXMRO", - "colab_type": "text" + "colab_type": "text", + "id": "z1JcS5iBXMRO" }, - "cell_type": "markdown", "source": [ - "# Step 1: Import Eager\n", + "## Import TensorFlow\n", "\n", - "The key imports for eager execution are the following:" + "To get started, import the `tensorflow` module and enable eager execution.\n", + "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later." ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "RlIWhyeLoYnG", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, - "cellView": "code" + "colab_type": "code", + "id": "RlIWhyeLoYnG" }, - "cell_type": "code", + "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", "\n", - "# Import TensorFlow eager execution support (subject to future changes).\n", - "tfe = tf.contrib.eager" - ], - "execution_count": 0, - "outputs": [] + "tf.enable_eager_execution()" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "H9UySOPLXdaw", - "colab_type": "text" + "colab_type": "text", + "id": "H9UySOPLXdaw" }, - "cell_type": "markdown", "source": [ - "# Step 2: Enable eager execution\n", + "## Tensors\n", "\n", - "All future TensorFlow calls will execute the\n", - "underlying TensorFlow ops immediately:" + "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. For example:\n" ] }, { - "metadata": { - "id": "WPTUfGq6kJ5w", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, "cell_type": "code", - "source": [ - "tf.enable_eager_execution()" - ], "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "twBfWd5xyu_d", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 3: Interactively Use TensorFlow!\n", - "\n", - "Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n", - "\n", - "TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors." - ] - }, - { "metadata": { - "id": "ngUe237Wt48W", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 125 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1526420535530, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "cellView": "code" + "id": "ngUe237Wt48W", + "outputId": "b1a1cd60-4eb3-443d-cd6b-68406390784e" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(3, shape=(), dtype=int32)\n", + "tf.Tensor([4 6], shape=(2,), dtype=int32)\n", + "tf.Tensor(25, shape=(), dtype=int32)\n", + "tf.Tensor(6, shape=(), dtype=int32)\n", + "tf.Tensor(aGVsbG8gd29ybGQ, shape=(), dtype=string)\n", + "tf.Tensor(13, shape=(), dtype=int32)\n" + ] + } + ], "source": [ "print(tf.add(1, 2))\n", "print(tf.add([1, 2], [3, 4]))\n", "print(tf.square(5))\n", "print(tf.reduce_sum([1, 2, 3]))\n", "print(tf.encode_base64(\"hello world\"))\n", - "print(\"\")\n", "\n", - "x = tf.constant(2)\n", - "y = tf.constant(3)\n", - "print(x * y + 1)\n", - "\n", - "# Most TensorFlow ops are directly usable with eager execution, giving\n", - "# results immediately.\n", - "print(tf.contrib.signal.hamming_window(x * y + 1))" - ], - "execution_count": 0, - "outputs": [] + "# Operator overloading is also supported\n", + "print(tf.square(2) + tf.square(3))" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "IDY4WsYRhP81", - "colab_type": "text" + "colab_type": "text", + "id": "IDY4WsYRhP81" }, - "cell_type": "markdown", "source": [ - "Numpy arrays are supported, too:" + "Each Tensor has a shape and a datatype" ] }, { - "metadata": { - "id": "lCUWzso6mbqR", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, "cell_type": "code", - "source": [ - "import numpy as np\n", - "\n", - "ones = np.ones([3, 3])\n", - "\n", - "print(\"numpy 3x3 matrix of 1s:\")\n", - "print(ones)\n", - "print(\"\")\n", - "\n", - "print(\"Multiplied by 42:\")\n", - "print(tf.multiply(ones, 42))" - ], "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "PBNP8yTRfu_X", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "# Step 4: Define and Print TensorFlow Variables\n", - "\n", - "To define TensorFlow variables, use the `get_variable()` function as follows:" - ] - }, - { "metadata": { - "id": "3Twf_Rw-gQFM", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 215, + "status": "ok", + "timestamp": 1526420538162, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "cellView": "code" + "id": "srYWH1MdJNG7", + "outputId": "5e4ac41c-5115-4e50-eba0-42e249c16561" }, - "cell_type": "code", - "source": [ - "x = tfe.Variable(0.)" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 2)\n", + "\u003cdtype: 'int32'\u003e\n" + ] + } ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "45G7094TxsMb", - "colab_type": "text" - }, - "cell_type": "markdown", "source": [ - "## Printing TensorFlow Variables" + "x = tf.matmul([[1]], [[2, 3]])\n", + "print(x.shape)\n", + "print(x.dtype)" ] }, { + "cell_type": "markdown", "metadata": { - "id": "UJBJeZ5XxuwA", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" + "colab_type": "text", + "id": "eBPw8e8vrsom" }, - "cell_type": "code", "source": [ - "# This does NOT print the Variable's actual value:\n", - "print(\"Printing a TensorFlow Variable:\")\n", - "print(x)\n", - "print(\"\")\n", + "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n", "\n", - "\n", - "print(\"Printing a TensorFlow Variable's value as a numpy array:\")\n", - "print(x.numpy())" - ], - "execution_count": 0, - "outputs": [] + "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n", + "2. Tensors are immutable." + ] }, { + "cell_type": "markdown", "metadata": { - "id": "2njjWHcTpBEn", - "colab_type": "text" + "colab_type": "text", + "id": "Dwi1tdW3JBw6" }, - "cell_type": "markdown", "source": [ - "## Changing a TensorFlow Variable's value\n", + "### NumPy Compatibility\n", "\n", - "To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:" + "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n", + "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n", + "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n", + "\n", + "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n", + "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory." ] }, { - "metadata": { - "id": "v3wr6Erbo_hB", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, "cell_type": "code", - "source": [ - "x.assign(42)\n", - "print(x)\n", - "\n", - "x.assign_add(3)\n", - "print(x)" - ], "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "uhtynjHVpTB5", - "colab_type": "text" - }, - "cell_type": "markdown", - "source": [ - "## Use a Variable just like any other Tensor" - ] - }, - { - "metadata": { - "id": "7PbktdnHoehR", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } - } + }, + "height": 251 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 238, + "status": "ok", + "timestamp": 1526420540562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "lCUWzso6mbqR", + "outputId": "fd0a22bc-8249-49dd-fcbd-63161cc47e46" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorFlow operations convert numpy arrays to Tensors automatically\n", + "tf.Tensor(\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 42.]], shape=(3, 3), dtype=float64)\n", + "And NumPy operations convert Tensors to numpy arrays automatically\n", + "[[ 43. 43. 43.]\n", + " [ 43. 43. 43.]\n", + " [ 43. 43. 43.]]\n", + "The .numpy() method explicitly converts a Tensor to a numpy array\n", + "[[ 42. 42. 42.]\n", + " [ 42. 42. 42.]\n", + " [ 42. 42. 42.]]\n" + ] + } + ], "source": [ - "print(x + 3)\n", + "import numpy as np\n", "\n", - "# This code will broadcast the value across the list of numbers:\n", - "print(x * [1, 2, 4])" - ], - "execution_count": 0, - "outputs": [] + "ndarray = np.ones([3, 3])\n", + "\n", + "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n", + "tensor = tf.multiply(ndarray, 42)\n", + "print(tensor)\n", + "\n", + "\n", + "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n", + "print(np.add(tensor, 1))\n", + "\n", + "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n", + "print(tensor.numpy())" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "GVChqwlwy1SI", - "colab_type": "text" + "colab_type": "text", + "id": "PBNP8yTRfu_X" }, - "cell_type": "markdown", "source": [ - "# Step 5: Debug Errors with Instant Feedback\n", + "## GPU acceleration\n", "\n", - "TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n", - "\n", - "Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n", - "one being legal and the other being illegal, leading to a runtime error that is\n", - "raised immediately." + "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. For example:" ] }, { - "metadata": { - "id": "23ap04N0v4k0", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - }, - "cellView": "code" - }, "cell_type": "code", - "source": [ - "vector = tf.constant([10.0, 20.0, 30.0, 40.0])" - ], "execution_count": 0, - "outputs": [] - }, - { "metadata": { - "id": "FCUMsIYxxRRa", - "colab_type": "code", + "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } + }, + "height": 53 }, - "cellView": "code" - }, - "cell_type": "code", - "source": [ - "# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n", - "# arguments) are within the bound of `vector`.\n", - "print(tf.slice(vector, [1], [3]))" - ], - "execution_count": 0, - "outputs": [] - }, - { - "metadata": { - "id": "T8me2oCNxpFp", "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } + "executionInfo": { + "elapsed": 340, + "status": "ok", + "timestamp": 1526420543562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 }, - "cellView": "code" + "id": "3Twf_Rw-gQFM", + "outputId": "2239ae2b-adf3-4895-b1f3-464cf5361d1b" }, - "cell_type": "code", - "source": [ - "# The following does NOT work, because the value of `size` (the 3rd\n", - "# argument) causes the indices to go out of the bounds of `vector`. The\n", - "# error is raised immediately.\n", - "try:\n", - " print(tf.slice(vector, [1], [4]))\n", - "except tf.OpError as e:\n", - " print(\"Caught error: %s\" % e)" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is there a GPU available: False\n", + "Is the Tensor on GPU #0: False\n" + ] + } ], - "execution_count": 0, - "outputs": [] + "source": [ + "x = tf.random_uniform([3, 3])\n", + "\n", + "print(\"Is there a GPU available: \"),\n", + "print(tf.test.is_gpu_available())\n", + "\n", + "print(\"Is the Tensor on GPU #0: \"),\n", + "print(x.device.endswith('GPU:0'))" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "irxJhAgar84v", - "colab_type": "text" + "colab_type": "text", + "id": "vpgYzgVXW2Ud" }, - "cell_type": "markdown", "source": [ - "# Step 6: Using the GPU\n", - "\n", - "You can explicitly place Tensors on the GPU by calling a Tensor's `.gpu()` method. The `.device` property tells you whether the Tensor is backed by CPU or GPU memory.\n", + "### Device Names\n", "\n", - "The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster." + "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:\u003cN\u003e` if the tensor is placed on the `N`-th tensor on the host." ] }, { + "cell_type": "markdown", "metadata": { - "id": "7J4N9baqaKCL", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "ZWZQCimzuqyP" }, - "cell_type": "code", "source": [ - "# Create some Tensors\n", - "SIZE = 1000\n", - "tensor = tf.random_normal([SIZE, SIZE])\n", - "print(tensor.device)\n", "\n", "\n", - "if tf.test.is_gpu_available():\n", - " gpu_tensor = tensor.gpu()\n", - " cpu_tensor = tensor.cpu()\n", - "else:\n", - " print(\"GPU not available.\")\n", - " cpu_tensor = tensor" - ], - "execution_count": 0, - "outputs": [] + "### Explicit Device Placement\n", + "\n", + "The term \"placement\" in TensorFlow refers to how individual operations are assigned (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:" + ] }, { + "cell_type": "code", + "execution_count": 0, "metadata": { - "id": "4E-2n7VbzY1n", - "colab_type": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 - } - } + }, + "height": 53 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1762, + "status": "ok", + "timestamp": 1526420547562, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "RjkNZTuauy-Q", + "outputId": "2e613293-ccac-4db2-b793-8ceb5b5adcfd" }, - "cell_type": "code", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "On CPU:\n", + "10 loops, best of 3: 35.8 ms per loop\n" + ] + } + ], "source": [ - "# Time a CPU-based matrix multiplication\n", + "def time_matmul(x):\n", + " %timeit tf.matmul(x, x)\n", "\n", - "print(\"Time to conduct matmul on CPU:\")\n", - "%time tf.matmul(cpu_tensor, cpu_tensor)" - ], - "execution_count": 0, - "outputs": [] + "# Force execution on CPU\n", + "print(\"On CPU:\")\n", + "with tf.device(\"CPU:0\"):\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"CPU:0\")\n", + " time_matmul(x)\n", + "\n", + "# Force execution on GPU #0 if available\n", + "if tf.test.is_gpu_available():\n", + " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n", + " x = tf.random_uniform([1000, 1000])\n", + " assert x.device.endswith(\"GPU:0\")\n", + " time_matmul(x)" + ] }, { + "cell_type": "markdown", "metadata": { - "id": "vbSFW-T5zhZF", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } + "colab_type": "text", + "id": "YEOJTNiOvnpQ" }, - "cell_type": "code", "source": [ - "# Time GPU-based matrix multiplications.\n", + "## Next Steps\n", "\n", - "if tf.test.is_gpu_available():\n", - " # First use of the GPU will be slow:\n", - " print(\"Time to conduct first matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)\n", - " print()\n", - "\n", - " # Subsequent uses are much faster:\n", - " print(\"Time to conduct second matmul on GPU:\")\n", - " %time tf.matmul(gpu_tensor, gpu_tensor)" - ], - "execution_count": 0, - "outputs": [] + "In this tutorial we covered the most fundamental concepts in TensorFlow - `Tensor`s, operations, and devices.\n", + "In [the next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/2_gradients.ipynb) we will cover automatic differentiation - a building block required for training many machine learning models like neural networks." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "TensorFlow: An introduction", + "provenance": [], + "version": "0.3.2", + "views": {} } - ] -} \ No newline at end of file + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb index 1e65b27bc8be8b05fefa38dffae7799b1e503bd3..9c1af9c2084bac7ae6369babeaa13720e6199097 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb @@ -7,12 +7,9 @@ "id": "vDJ4XzMqodTy" }, "source": [ - "# Eager Execution: Working with Gradients\n", + "# Automatic Differentiation\n", "\n", - "This notebook demonstrates:\n", - "\n", - "* How to get gradients using TensorFlow's eager execution capabilities\n", - "* How to apply the gradients so you can update your variables" + "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models." ] }, { @@ -22,7 +19,7 @@ "id": "GQJysDM__Qb0" }, "source": [ - "# Setup: Import eager and enable eager execution.\n" + "## Setup\n" ] }, { @@ -40,12 +37,10 @@ }, "outputs": [], "source": [ - "# Import TensorFlow.\n", "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", "\n", - "\n", - "# Enable eager execution.\n", - "tf.enable_eager_execution()" + "tfe = tf.contrib.eager # Shorthand for some symbols" ] }, { @@ -55,28 +50,15 @@ "id": "1CLWJl0QliB0" }, "source": [ - "# Fitting a Simple Linear Model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "-39gouo7mtgu" - }, - "source": [ - "## Step 1: Synthesize some data\n", - "\n", - "To demonstrate fitting a model with TensorFlow's eager execution, we'll fit a linear model to some synthesized data (which includes some noise).\n", + "## Derivatives of a function\n", "\n", - "In the code, we use the variable names `w` and `b` to represent the single weight and bias we'll use to fit our model." + "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: " ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -84,105 +66,53 @@ } }, "colab_type": "code", - "id": "rQsdCg9PfIL-" + "id": "9FViq92UX7P8" }, "outputs": [], "source": [ - "# The constants we'll try to fit our variables to:\n", - "true_w = 3\n", - "true_b = 2\n", - "\n", - "NUM_EXAMPLES = 1000\n", + "from math import pi\n", "\n", - "# Our inputs:\n", - "inputs = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", "\n", - "# Our labels, with noise:\n", - "noise = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", - "labels = inputs * true_w + true_b + noise" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 347 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 374, - "status": "ok", - "timestamp": 1525154227149, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "O4lsC4ckAcar", - "outputId": "f8becb3f-498b-4cb7-9ef3-608a68cb65d0" - }, - "outputs": [ - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAecAAAFKCAYAAAAnj5dkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xt8VPWdP/7X3M5MkpkkM8mEAAER\nQoICgUBALkUEQ7FucekDEeWL3VZXu121dler39pu1Vbb77b+2m1/3277qNXa2kUptGttt/tDEWqp\nyDWBiC6ES8slXDJJJpfJ3C+/P8JM5nLOmTOTmWQm83r+RebMnJyTAO/z+Xzen/dbFQqFQiAiIqKc\noR7rCyAiIqJYDM5EREQ5hsGZiIgoxzA4ExER5RgGZyIiohzD4ExERJRjtGN9AWE220DWzm02F8Nu\nd2bt/LmukO+/kO8d4P0X8v0X8r0D+XH/VqtJ8lhBjJy1Ws1YX8KYKuT7L+R7B3j/hXz/hXzvQP7f\nf0EEZyIionzC4ExERJRjGJyJiIhyDIMzERFRjmFwJiIiyjEMzkRERDmGwZmIiCjHMDgTERHlGAZn\nIiKiJDy+ADrtTnh8gVH5fjlTvpOIiCjXBIJBbNt9Gq3tNvT0e2Ap1aOxzopNq2uhUWdvfMvgTERE\nJGHb7tPYdfhi5Ovufk/k683NdVn7vpzWJiIiEuHxBdDabhM91treldUpbgZnIiIiEX0OD3r6PaLH\n7ANu9DnEj2UCgzMREZGIMqMellK96DGzyYAyo/ixTGBwJiIiEqHXadBYZxU91lhXCb0ue20pmRBG\nREQkYdPqWgBDa8z2ATfMJgMa6yojr2cLgzMREZEEjVqNzc112LByBvocHpQZ9VkdMYcxOBMRESWh\n12lQZS4ete/HNWciIsqa0a6sNV5w5ExERBk3VpW1xgsGZyIiyrixqqw1XvDxhYiIMmosK2uNFwzO\nRESUUWNZWWu8YHAmIqKMGsvKWuMFgzMREWXUWFbWGi+YEEZERBk3VpW1xgsGZyIiyrixqqw1XjA4\nExFR1ox2Za3xgmvORESUMawIlhmKRs7t7e34x3/8R3zmM5/Bli1bcPnyZTzxxBMIBAKwWq34zne+\nA0EQYj7zzW9+E8eOHYNKpcJTTz2FhoaGrNwAERGNPVYEy6ykPzGn04lvfOMbWLp0aeS1H/zgB9i8\neTO2bt2K6667Djt27Ij5zMGDB3Hu3Dls27YNzz//PJ5//vnMXzkREeWMcEWw7n4PQhiuCLZt9+mx\nvrS8lDQ4C4KAF198EVVVVZHXDhw4gFtvvRUAsGrVKrz//vsxn3n//ffR3NwMAJgxYwb6+vrgcDgy\ned1ERJQjlFQE43R3apJOa2u1Wmi1sW9zuVyRaeyKigrYbLG/lK6uLsyePTvytcVigc1mg9FozMQ1\nExFRCjy+QFYzppNVBHt150mcPG/ndHcKRpytHQqFMvIes7kYWm320uytVlPWzp0PCvn+C/neAd5/\nId+/xVKCl3/3IfYfvwxbrwvW8iIsmTMR962bDY0mc4HRVFYEq7kInXZXwjG9oMW+41ciX4enu4uL\nBDywfm7GrkFMPv/u0wrOxcXFcLvdMBgMuHr1asyUNwBUVVWhq6sr8nVnZyesVvFqMWF2uzOdS1HE\najXBZhvI2vlzXSHffyHfO8D7L+T7t1pN+L+/ao3pDNVpd+HNvWfhdHkz3hmqYUZFzPcKC4WCou9/\n79glfGLxlKztfc6H373cw0Naj07Lli3Dzp07AQBvvfUWVqxYEXN8+fLlkeMffvghqqqqOKVNRDSK\n3F7/qHaG2rS6Fs1NNagoNUCtAipKDVg+pxpur3hwZgMMeUlHzsePH8e//uu/oqOjA1qtFjt37sQL\nL7yA//2//ze2bduGSZMmYf369QCAf/qnf8K3vvUtLFiwALNnz8bdd98NlUqFp59+Ous3QkREw+z9\nyTtDZaI4SPR6dnxFMAA4cd6ObpHrYAMMeUmD85w5c/Dqq68mvP6zn/0s4bXvfe97kT8//vjjI7w0\nIiJKl7l0qDNUssCYLFlM6rjcvubooN9YZxWd7mYDDHks30lENA54fAHYel1AKASruRhWQSsbGLUa\nFbbuapcsGpKsqEh4X3NYONELQMx6tlgDjIYZFqxqnAyPL8AALYHBmYgojwWCQbz+zim898EVuL1D\n68gGQY3mxdfhzlumAxDvDJUsuMod37Byhsx6tg0bVs6IBN3oBhg9/W7sOnIRbae78MfWS9xWJYPB\nmYgoj23bfRrvHOmIec3tDeL3f/4L3G6faGeoZEVD1i2bJnv85nmTJNezu/s9eHXnSXz29lkxAVev\n02BPawf2tHTEvFdstE1sfEFElLfkgiwAtJy0RaaOq8zFkdFssqIhFzsdsscRCsFSKp3Mte/4lYSy\nnUqqiNEwBmciojE0krKWckEWAOwDHtHtSmVGvWRwNZsMqKkySh4XdBpYyorQWCdfuyI+4CZ7IOC2\nqlic1iYiGgPpdHGKz5wOB1mxjGwAMJv0otuV9DqNbLKYqViQPO72BvDG3rPYtLoWLrcf70VV/4oW\nv11L7lq5rSoRgzMR0RhQmu0MyAdyqSAKAAvqrZLZ0GJZ1OFkMQBYv+J6/LntciTJLFprexc2rJyB\nLWvr8T/netAz4E14T3zATfZAwKztWAzORESjLNn6a3S2MyAfyDetrkUoFIrL1tagefFU/O2y6ySv\nITqLWmwfs8Ppg0ckMAOxo+IF9VWKA26yBwIaxuBMRDTKlKy/hqeDlQTy/7WmHnfeUhuzz7lmUrmi\n2tLhZLF4SqehUwm4yR4IaBiDMxHRKEtl/VVpINfrNKixKu9hkKwymNJp6HQCrtQDAQ1jcCYiGmWp\nrL9mOpHK6fFh69uncOJcD+wDXlhK9WiorUTzwhpYSg0x3zuVUTEDbmYxOBMRjQGlgS9TiVThpLL4\nJK/ufg/2tAwVB6mIyxgXGxUDQHefO/JnTk9nB4MzEdEYSGU6eP2K6+F0+3HinB29Do9oIE82TR2f\nVCZGKmNcr9OgoswQyRjv7vfAIKgBqODxBliGMwsYnImIxlB4v7LSzk9LZ1fjnjV1KNZrJd/TWGfF\nw3c1Rs6TrJJYPCUZ49F9mlmGM/MYnImIRlH0CFerUaXc+em941fg8vrxd7fNgqlYkNxmVVwkYP3y\naQCSVxKLl0rGeDSxoE7pYXAmIsqAZNPKYiPcYoMOFzodkfco7fzU0t6F1vY/Y7K1BE63T/Q9+49f\nxicWT1FUSSye2aSH1xeI1OVWGtzjgzqlj8GZiGgElJbhFBvhSgXLZJ2fACAE4KJtUPJ4V68LZzv6\nMH1ymWxSmZhBtw9Pv3woci/rV0xXFNxZhjNzGJyJiEZASRnOVNd8ozs/KR3tJlAB33n9aCQDO7G3\nsx71U83QaVU4ftYO+4Abgk4DtzcQWU+OvhclwZ1lODOHwZmIKE1Ky3CmuuZbbtQDKhUaaitj+h+n\nIngtXyv+YUEsO9zjC8DW68K//eqoZC3tZ+9fhEAgiNZTXeh1eGEQhj7r9QVYhjMLGJyJiNKktHpX\nqmu+To8fT790EGaTgImWYlzucY74WsMPC2L0Og0ErRp2kQYWwNC9bH37FE6et6PP4YXZqMf8ukps\nWDkdDqeP+5yzgMGZiChNSqt3KV3zFbQqeP2hyOh1qNuTF3qdGh5fUPazydgH3Hh150mcPG8XXRuX\nuxdBp8G+qNaQdsdQ4RKNWsWtU1nC3eJERCny+ALotA+NZhvrrKLviV9/3bS6FsvmVMue1+sPib6u\nUqV5oVEEnRr7jl9Bd78HIQxPd7/+zikAww8Q4sSvq7W9Cx6feOcqGhmOnImIFBLLzJ4/sxKrF07G\nsVPdsmU4NWo17l1bj5Pn7Sknebm9QSyfU40j7TbRNWElfH7xkfd7H1zBnbfUQq/TYNPq2si6cp/D\nC0upAbOmluO9qFFzNG6dyh4GZyIihcQys9850oHmpho898BNSctwprqlKUytGirh+T/netIKztWW\nIlzpcYkec3uHksEmVhRj2+7TaDvTjT6HF+VGPRpqK7Bh5QyckHig4Nap7OG0NhGNC+Gp5mxNsybL\nzAYQad0oJRAMIhQKRTKdlQqGgE67SzJhKxmXxy//hlAo8uARnvYOryu/sfes4ql7yhyOnIkor8kV\nAckkpZnZcrbtPo13jqS+Ncpi0qOmypj2vue+QR/0WjU8IlPbBmGogpjcg8ez9y+K/DlZ60jKDAZn\nIsprckVAHr1nYca+z0j7KqdaiCRaSZEOGo0K9VPNMVnTSpUbBcybWYl3Wy8lHFs2txouj1/2wcPh\n9CnuoEWZweBMRHkr2VSz25tkOjcFep0GDTMqsEckwCmZ3k21EEm0C50OPP7DffB4AzAIGoRCoZS2\nVjXOrMTmNXXQadRoOWlDz4AHZSU6LKiz4p5bZ8IfkK5GFr8ljMlfoyPt4Lx9+3a8+eabka+PHz+O\n1tbWyNezZ8/GggULIl+/8sor0Gj4pEVEmZNsqtne78nICCQ8dd52phvAUIJWMDQ03bygXnoKPboZ\nRqqFSOKFE8FSTQibaCnG5jV10KjVQ9nYwRCOtneh1+FB25luaDSnsWl1rWSiGteVx0baf283btyI\njRs3AgAOHjyI//7v/445bjQa8eqrr47s6oiIZCSbajaX6jHQJ56lrEQ4uO48dCGmjGbw2rbfeTMr\nRYtwSK2Dz59Zmdaas5jwA4Icg6DGV/6uKdKAY9vu0zH3Eb0EEH7A4LpybsjItPYPf/hDvPDCC5k4\nFRGRYnJbkxrrKmEQtBgQ+Vwq7R27+z1QSxQBaTvdDc+qQMI5tr7dHjP9HQ6CtyyYhBpriWw3KRWk\nSn7EShaYAeBjDZNQrB/6b37A6cXh/+kUfV+4tCfXlXPHiINzW1sbJk6cCKs1NtXe6/XiscceQ0dH\nB9auXYvPfvazsucxm4uh1WbvL4LVasraufNBId9/Id87MP7v/+G7GlFcJGD/8cuw2V0wl+qxZM5E\nPLh+LoDY+w8Egnj5dx8OvbfXBWt5EZbMmYj71s2GRjO8s/TFNz6ICfhSgdA+4IZG0MFaWRI5/0/e\n+ADvHktclwaAAx9ehcsjPy2tJDADgLXcgAX1VXj70PlIk4toRXotHlg/F3pBi5++eRxvHzwHj1d8\nnTr+PmoUXkOuy+e/+yMOzjt27MCnPvWphNefeOIJ3HHHHVCpVNiyZQuampowd+5cyfPY7SMv7C7F\najXBZhN7fi4MhXz/hXzvQOHc/7qlUzEw6MFRXxfs/R4cOH4ZXq8fD9/ViJ6e4VHq1l3tMUG30+7C\nm3vPwunyxrR3fO+Ysqlns8mAgNcX+RnHnz9essAMDK1jz5tZibbT3TFtHOPNq63E8jnVeOvAedHz\neLx+/OWCHbuOXExa9CT+PsaDfPi7L/fwMOLgfODAAXz1q19NeP2ee+6J/HnJkiVob2+XDc5EROmS\nWkstLhKwfvk0ANlp7xidLDWSrVLRFtRbsbm5Dp5VQ1PvxmIBb+w9G7MWPG9mBUKhEP7tV0clR9qV\n5UUo0mvRclJ8KlvqPig3jKhC2NWrV1FSUgJBEGJeP3v2LB577DGEQiH4/X60tLRg5syZI7pQIiIx\nckFx//HLkYphSoqIAMNJZmLUqqEmFBWlBjQ31WDT6tpIZTKb3Zk0qMtVBrOY9JFzAsPblor1Wmxu\nrsNzD9yEbz64BF/7TBM8ngDeOdJxrWuVuCVzJg7tX05SVWzZnGomfeWgEY2cbTYbLBZL5Ouf/OQn\nWLRoERobG1FdXY0777wTarUaq1evRkNDw4gvlojGv2TJWvHkgq7N7sLZjj5Mn1yWkfaOK+dPwtrF\nU1Fm1EOrUSVkZAs6FTw+8bGsXqfGktlV+GPr5YRjy+dUY8vaetn71WpU2HXkIlpOdsoG3IprmeH3\nrZuNy1f7YTEJku+3lOpx79r6SDY35Y4RBec5c+bgpz/9aeTrBx98MPLnL33pSyM5NREVGLkynHLB\nQy7oqtTAC68fjZxr3sxK7BbZyiTW3hEQ31YUvpb49eVk+5c9viBuXVgDrUYje14p8ZXQxKgAPHpn\nA2qqTNBo1NDrNFhQXyX5uQV1Vk5n5yhWCCOinCBXhlNsL3GY3Eg3nMUcPtetCyejualGci9v9Kg9\nflsRAHT3uSN/Tmd9efeRDty7dlbS7UrxswceX0DR2rGl1ABrXAWvTatrEQyFsO+DK5HEMoOgwfK5\nnM7OZQzORDTmlCZrSYke6fb0u6GSKNBx9FQ3nnvgpoTgGAgGsXVXu+iovaLMkDCir59qTqsUZ9uZ\nHnh8AckymFKzB6saJyddOwbEE7s0ajW2rKnHxltqYbM7AZUK1vIijphzHIMzEY25kXZ80qjVkZHu\n2Y4+vPD6UdH39Qy4I2vQ0eeTG7UHAsGEgiL7jl+BRF0SWcnuRap4idcfkK0IJuhUWNEwSXYkrNdp\nUFOVv/t+Cw2DMxGNuZF2fArT6zSYPrlMeg0awHdePxpJmtq0uhb+QEhy1L637RK8EoU7lBYLiSZ1\nL0Mj91N496h48ZK2092yFcG8vhBUKhUTu8YR/iaJaMyF143FpLoHV+5c4QAXHpFu231adtTu8QbT\nCsJSpO4lvE9bKgD3ObwoNwriB69pbe+KbBsDALfXj067M+Y1yh8cORNRTshk44XwZ9rOdMPW64IK\n4lPCre1dWLds2oi6RSlRbhTQNKtK9F6UFC+xlBrQUFsRU2glXnjKPLxG3namGza7S3HWO+UWBmci\nygnR68YjbbwQPtfnNhTh4LEOfEdiDdo+4IbL45fM9s6EcqOAZ+9bDFOx+MhXSUWy6IeUd1vFR9jh\nKfN0s94pt/AxiohySjiTOZXAHK7SJTaFayrWoUKi4lc4oG1aXYtbF05GNgaWbq8fb+w9i8vdg6LX\nl6wi2c3zqrGqcTL8gRDuWlWLxTdMEH1vY10lAOktXvHT3pTbOHImorwltfXozlumY8cfz0amdvWC\neNSNXgMOBkOi3Z1SMVTeU4VA1NDW7R3K9t7TeikmES08xSy3T3tSZQk+/Isde49dgV7QAAjB7Q3C\nIKgBqOD1BWKm/7v73CPKeqfcweBMRHlLagr35PleXOh0RF53X8u41qiBwLUAbBA0CIVCCASDQxnb\np7pGdC2L6q24Z00dnvv5Ick9yVJTzGLr7cUGbdw9RCd7Dd3EsjnVuDeq7Gemst5p7DE4E1Fekkuk\n6rA5RF8PRI2M3d6h5hHBENBUZ0WvI3mRDznGEgFeXwB2BcVC4gurxK+3F+m1+Porh5Ke5+T53piv\n5Ubh7DyVXxiciSjjUm1ekY6efrdkhrXcnuB477Z2yGZBp3Ieh9MLs0yjiTD7gBu2XhcErTrmZxRe\nb+9U0OEKGPoZxE9VR2eqd/W6RpT1TmOHwZmIMibd5hXp2HVEOrtarppWvFQCebLzHDphg0bBbQo6\nDb63rRV2hw8Wk4AF9VUxPyO56eloekGTMFUdnal+5q/dWX1AouxhtjYRZUx4Dbi734MQYot9ZJLH\nF0Dbaek14urKsUt6CihIKnN7A7A7fACAngEvdh2+iNfeORXJOgcgWUhFKYOgTTnrnXIHR85ElBEj\nbV6RCrkpbQCwlhlwyebMyPfKJLUKQAgQi99/bO3A0XYb7ANeWEr1mDezErcunIyWk12wO8Tv1Xtt\n+YAZ2OMPR85ElBFKmldkytuHz0seU6uAY6d7Mva9MikoEZiBofaWPQPeyIzD7iMdUKlUeOa+RZKl\nO5mBPX4xOBNRRsgV05ALIh5fABc7B3DR5lBUJMPjC2D/h9K9jTO1hpwLWtu7IOg0aJpVJXqcGdjj\nF6e1iQjAyDOsU93GEwgG8do7p7Dvg8uRfbsGQYPlc6tx960zJRPIbL2umD2/41l4xiGTdccpPzA4\nExW4TGZYpxJEtu0+jd1HYrcwhfceq1QqbG6uE39gCOXe0DiV7PBUhGccMll3nPIDgzNRgctkowSl\nQcTjC6DlpPTUdGu7DYFAEG1nuhMeGKzmYhgEdWS0nQuyNZUeP+MQ3gdN4x/XnIkKWLIMa6WNEuIb\nTyRrXtHn8MgW6uju92BP6yXRLVl6nQY33Sje/GG8ELRqNDfVcNq6gHHkTFTAlGRYy43U0p0SL9Jr\nIWhU8AbEh5xS08QtJ20IBEM4fjY3s7GTUauGZuXNJj2cHr/o2rleq8a3/mEpypmFXdAYnIkKWCqN\nEsTWf1OdEo8O5lKBGZCeJu4Z8GSk1OZYWXzDBKxfcT3KjHr8+t0zoslzK+ZPYmAmBmeiQqYkw1pq\ndLx+xfUpFx2JD+ZSykp06Bv0JbyercSr0WAQNNiyth7F+qH/dpmBTXIYnIkKXLIgITU6drr9KU2J\ny61vR6soNWDuDDP+2Ho54Vi+BmYA+FjDxEhgBpQnz1FhYnAmKnByQUIuoJ44Z1c0JR6eDvf6Aoo6\nLTXWVSKYg9ul0lVRKj8iZgY2iWFwJiIA4kFCLmGs1+HB4hsmoPujqwnH6qeWIxAMYeuu9pjpcL3M\nFihTkRY3Xl+B25dch+d/cXjkN5QDyo0CvvaZJpiKxctvEklJOzgfOHAAjz76KGbOnAkAqKurw7/8\ny79Eju/btw/f/e53odFocPPNN+Ohhx4a+dUS0aiSSxgTdBq0X+wFMLwWrNepgJAK7x+/gpb2zphA\nnKz94YDLjwMfXR1KFvPnzh7mkegf9MLl8TM4U8pGNHJevHgxfvCDH4gee+655/DSSy9hwoQJ2LJl\nC9auXYvaWiY6EOWD6MxsqYQxtzcQ2QoUXgv2+EIAQteOiwdYg6BGKAR4fOLHMxGY9Tq15PnV6qEm\nE+kyCBrF5UPZmILSlZVp7QsXLqCsrAwTJ04EAKxcuRLvv/8+gzNRjhPLzJ4/sxKrF07GsVPdsA+4\nUW6U3qOrRCYre5mNAubXWdF2ujsmmW353Go8+zPxqfFQEBC0Knj9qa1rq1XAyvmToFar8M4RZdu5\nGmZYmORFaRlRcD59+jT+4R/+AX19fXj44YexfPlyAIDNZoPFYom8z2Kx4MKFCyO7UiLKOrHM7HeO\ndKC5qQbPPXDTUGKXP4inXzo4hlc5zHXtAeHhDXOgUathLS+CVqPC1rfbJbddhQD4ZfZYSwmGgOam\nKagyF0GlUqG1vQs9A27ZUt/NTVNS/j5EwAiC87Rp0/Dwww/jE5/4BC5cuIBPf/rTeOuttyAI6a2t\nmM3F0Gqz94RptZqydu58UMj3X8j3Dii/f6fLiz+3JW5fAoDWU134zLo5qJlUDrfXD6u5CJ12VyYv\nMy1ubwB7Wjqwp6UDVnMR5s6ohKBTY0/rJdnPpbsl670Pr+LzG+bh0XsWwu3140r3IL7+0/2w9boT\n3qtWD73/wfVzodGMTaVk/t3P3/tPOzhPmDABt99+OwBg6tSpqKysxNWrVzFlyhRUVVWhq6sr8t6r\nV6+iqkq8H2mY3e5M91KSslpNsNkGsnb+XFfI91/I9w4ov/9AMIiv/fSg5FR1d58bD39nN5pmVWHT\n6lo0zKhQVExkNNnsLuw+nN0Zuv0fXMa6pddFpqpLtGrMq60U/VkEg8Af9v0VXq8/5QYimcC/+7l/\n/3IPD2k/zr355pt46aWXAAxNY3d3d2PChKFi9DU1NXA4HLh48SL8fj/27NkTmfImotwQ3axi665T\nuNwj/4Dc6/BGmk9sWl2L5qYaVJQaoFYNJUllmirjZxy5ngEP+hyxWeebVtdiVeMkqCUuOJUGIkRh\naY+cV69ejccffxzvvPMOfD4fnnnmGfz+97+HyWTCmjVr8Mwzz+Cxxx4DANx+++24/vrrM3bRRJS+\n+KQvs0nAoMuv+PPh0pzRhUuMxTq8sfcvOHyiE70O6W5TqcjFMiRq1VDTjmgatRprF0/FHyWm0pU0\nECGKl3ZwNhqN+PGPfyx5fNGiRdi2bVu6pyeiLIlP+pJr3SgmOthEFy7Z3FyHdcum4emXD2YsQOea\nYAii+5ZTaSBCpAT7ORMVEKX1reWIdasKT4+bigUUG8Zv4cGKUr1ooA03EBETbiBClIrx+6+IiBLI\nleNUSq5bVZFei0td2UvuHGuNdVbJQMsuU5RJDM5EBaTMqIfZJIhOZet1ahiLdOgZ8KC8RI95MysQ\nCoVw7HQ3+hxeWEqTd6sCRhb4c9nK+ZNkAy27TFEmMTgTFRC9ToOSIvHgXGUuxlP3LoxJ8Gptt6HP\n4UW5UY+G2gpsWl0LjVqdkenxfPOJm6ZCo06+EsguU5QJDM5EBcTjC8Dp9okec7p9sNmdsJqL8et3\nz8SMiu0OD/a0dECjVmFzc11GpsfzidmoY1IXjSoGZ6ICIhdUu/s9+NrLh2AxCXB6xPfltrZ3Yf2K\n6fjDgXNQqSBbunI8MRbLT1FHNwrhVDZlAoMzUQGR2/ITJre1yj7gxvM/P5y0YMl4M+jyweMLJARe\nsaS4xjprZPqfKF3820NUQOS2/Cih1agKLjADQK8jsTIYMJwU193vQQhDsw/hKmpEI8HgTDTORO87\nFhNdelOVYo1Mf7odI/KcWCERuaQ4luykkeK0NlEOS2UtU2yKdfm8yVi3NDbLOLzlZ/2K6/HLnSdx\n4KNOxaUyg5lrxZxXxAqJyK3fs2QnjRSDM1EOSmctU2zf8Zt7z8Lp8op2RXpj71+w/6POrN1DPlOr\nhmp7W2QKibBkJ2UTgzNRDhILtOGvxQJtsinWDStnxIz8CnGfcipWzp+EtYunxsxYDDi9uNjpQE2V\nEaZiIbJ+L9YukiU7aaQYnIlyTKqBFkg+xWrrdUHQqiPBps/hkc3YLiRTqoxwuv0JJTfDMxRevx/P\n/6IFHTYHgqGhUfVkqxFf+fQCluykrGFwJsox6axlyk2xCjoN/u1XR2Ef8Eamx29fMhVq1VCXpZEQ\ntGp4/bm/EF1jLcGDd8zGntYOtJ3uTgik/kBIcm3/+V+04EKnI/J1MARc6HTg+V+04Nn7FrNkJ2UF\ngzNRjklnLVNuitXtDcDtHcocDk+PO93+EQdmAPjnTQ048FEn3j16KSPny4abGybi3tvqoVGrce/H\n6+FZlZhkp1FDNHlrwOlFh82R8DoAdNgcGHB6I1PcTP6iTOJWKqIck077wUAwiFAoBIMwfMwgqGEQ\nxP+Jnzhnh8UkiB5LxSt/OIlFCLMMAAAgAElEQVSb508as0phgjb5XjCNJvY94UCqZIR7sdMh+dAR\nDA0dJ8oGBmeiHBS9F1mtAipKDWhuqpFcy9y2+zTeOdIRGSEDgNsbhNsrPuXc6/DghussI77OK3YX\n/s8vWyDoUtwwnQHlJTo01FZCp5X/b2xP66W0i4LUVBmhlrg1tWroOFE2cFqbKAeF9yKvWzYtJkM4\nnscXgM3uTDnzWtBpcM+aOpzvdMSsp6bD4xubNefeQR8On1B231KJdMmYigVMthpFf0aTreK/E6JM\nYHAmyiHhoiPRLRvF9jlH74NOJ+s6FArhavcgBl3SdbTHk5EUBfnKpxdIZmsTZQuDM1EOiC86ohc0\nMVPU8fuc4/dBp8rjC+Ibvzgy4uvOFyMpCiJotXj2vsUJ+5yJsolrzkQ5YOvb7TENFKIDc7TW9i4M\nOL0sIJKiTBQFMRULuGGahYGZRgVHzkRjKBAMYuuuU3j36CVF7+/pd+Nip0NyH3Suq7YUodPuyvq2\nK/W1XtNWcxEaZlSwKAjlHQZnojHi8QXwy50n8d7xK4o/oxc0qKkyJu3JnItunj8RaxdNRZFei1f/\nvxM4cd4OlzeYkWIo8UIAHr97PhbPm4yBPldmT040ChiciUZZeH35yImrsDt8KX46hDf2nsWgO9XP\njb332i7jT0cvJ7yejVF0eYke0yeXwSBoMZD50xNlHdeciUZZOJkr9cA8tHd5T+ulhP3LGjWwsnEi\nKkpztxNSIMmOK71O+X9HS26sgtmokzw+n40nKM8xOBMl4fEF0Gl3wuMLiH6d6rmykcwVCAJqlVqy\nslg+ULJf2iBo0NxUg/s/eSMWzpog+p4pVUZsbp6Z6csjGlWc1iaSEL+9yWwSUFIkwOn2Ke6xHC+b\n3aBaT9owt9aSN80o4pmKdBB0atmfT4lBiw0rZ0CjVsd0hOrpd6PMKKBxZiU2r6lT/PsgylUMzkQS\n4vcS9wx40TMwXLQjWY/leIFgEDsPXchKAhQA9A56sfeY8uSyXNNYXwlBq5Hdv20f8ESKiYSrqLEj\nFI1HIwrO3/72t3HkyBH4/X587nOfw8c//vHIsdWrV6O6uhoazdA/lhdeeAETJohPQxHlmlSmn5WW\nhty2+zT2tHSM6LoErQpef462f7rm5saJOHuxHxdtg4o/o9WocO/H6wEAgWAI77Z2iD7AiBUTYUco\nGo/SDs779+/HqVOnsG3bNtjtdnzqU5+KCc4A8OKLL6KkpGTEF0k02uR6KsdTUhoyU2vNC+qtOH6m\nBw63f8TnyhatSoWnP7sIW99uR+upLvQ5vBB0atk15dJiHfyBELQaFTRqFXRa8fdnopgIUT5IOzgv\nWrQIDQ0NAIDS0lK4XC4EAoHISJkon8n1VI6npDSkXLBXqQBTkYB+p3yda4OggVarzunADADvHb+C\njatm4t61s3DX6qFa4V5fAF97+ZDkZ+wOL/ocHuw6clF0WnsoG30yi4lQwUg7OGs0GhQXD40UduzY\ngZtvvjkhMD/99NPo6OjAwoUL8dhjj0Glkm4rZzYXQ6vNXmC3Wk1ZO3c+KOT7T/fel8+bjDf3nlXw\nvkmomVQu+x5TWRGs5qHqWAnXV16EuuvK8WeRPcDRqiuKse+D3F9T9niD8KtUqCwrwmC3EyUmA2pM\nBljLDbD1ukU/Yy0vQs2kcrT96pjo8UAQMOh1qJ5QlvL18O9+4crn+x9xQtiuXbuwY8cOvPzyyzGv\nf+ELX8CKFStQVlaGhx56CDt37sRtt90meR673TnSS5FktZpgsxVuKYJCvv+R3Pu6pVPhdHnR2t4F\n+4Ab5UY9Sop0cLp9sA94YDYZ0FhXiXVLp8JmG4h0lJJKTGqYUSE6Kuwf9CQNzBMtxfjr5fz5Hf7i\nDx/hg9Ndkf3YBkGDynKD5PsbZlTg4qVe0YeXsPfbLmPd0utSmtbm3/3CvHcgP+5f7uFhRMF57969\n+PGPf4yf/vSnMJliv8n69esjf7755pvR3t4uG5yJco1GrcaGlTNwc8NEQKWCtbwIep0mIQgP1cdu\nl2zvGHbnLdNx8nxvpPVgWHxBkXhmowCvP/U91WPp0EedMV+7vQFc7BxEtaUI9gFPZD3ZIGiwfG41\nNq2uhT8QQrlRQK9DfHq/d9CTdttHonyTdnAeGBjAt7/9bbzyyisoLy9POPbFL34RP/rRjyAIAg4d\nOoS1a9eO+GKJRkv8HufogBufHRy/5Sp+i1U4mO88eB4XOh0pX8sN0yx4P4X627nsSo8LZpMejTPL\nsPam61BtKY6MhDVqoHFmJfa0ijcBsYyg7SNRvkk7OP/hD3+A3W7HF7/4xchrN910E+rr67FmzRrc\nfPPN2LRpE/R6PW688UaOmimvSAXcQDAU2fIDAANOLw6f6BQ7BVrbbQgEgmg7042efg9kUi4kLZ9T\njXvWzMTJ8/a8a3QhxT7gwf6POqFRq7FlbX3Msc1r6nC6o1/0IYaZ2lRIVKFQKCc2TWZzbSAf1h6y\nqZDvP5179/gC+OqL+0WDoVoFLLphAjavmYk39v4FR050ot+ZnSYUpmItHr2zAZOtppS7V+ULi0nA\ngvqqmCWAcBvNo+1d6B30wHJtbT+VSmxh/LtfmPcO5Mf9Z23NmSjfhaeci/RauDx+lBn1stuegiHg\nwEdXceCjq0nPrbrWUzhdA04/nvtFCwyCGotvqIJBUCddn843PQPehCprGrUad62qxarGyUAoBKu5\nmCNmKjgMzlSQwmvKLSc70TPgjZTUrCjVY870CpSVCOgdlN93nEym5qTc3iD+dOwKplQZ01qzzgfh\nKmtajUpyrZ/1sqmQMDhTQYpfUw5nT3f3e/DuUfGEpLE2MDg+1pzFhKusxRchSbV+OdF4wUdRKjjZ\natuYbb2D2VnbzgVmkwFFeq3k76W1vSutFp1E+YrBmcalcM9ltzex1GUqdbMzqdwoYGXjJOi16f2z\ny9VlV4Mw8v9GGusq4fL4JX8v4ZE1UaHgtDaNK/H7k63mIjTMqIhZs0ylbnaYkh7J1RVF6OxxiXZT\nUqmAL909H5ayIrSf68XlntQr4uXqwLGi1ICOLvn7MQgaeH0BmE16FBt0GHT50OsYrrIWLkIi9XtR\nUr+caDxhcKZxJX4tudPuSliz1Os0aKyzyvYNjmYx6THnegsOfHQVnmsBWqMeanPo8YWgAhAC4HL7\nJfs0h0LAsz8/BBVUst2Z8pHD5cOqBZPRdrob3f3uoZkBNeDzBSPBd/2K6XA4vZGqamKlTjVqSP5e\nuMeZCg2DM40bcmvJ8T2Xw92NWk7a0DPgiWRriykp0uFPbbG1rwNBYIKlCJe6nAh/rC/JmrDXFwKQ\nE2UFMqp/0Ie1i6bgrlW1kYALICH4FuuH/7uR6sEc/r2E65lHj6yJCgmDM40bcmvJ8T2XNWo1NjfX\nYcPKGZF9zg6XD7sOX0DbmZ5IYGiYYUHbmW7Rc15KMpVbKCylhkgQjg646dTAjv+9SDURIRrvGJxp\n3JBbS5Zas4wOKKZiAfeunRUz5WqzOyVrPcspLdZlrXJYrsnGlLPUyJqoUDBbm8aN8FqymPgAEs7m\njt+eEw7MxmIBv373DL63vS3l67CY9Nh068yUP5ePjEVaTjkTZQFHzjSuxK9ZVpYPZ2sD0t2m7rxl\nOnb88Sxa223o7vcoys6W4nB58dPffZSxe8plTrcfTrcfpmJhrC+FaFxhcKZxJX7Ncsa0Cgz0uSLH\npbpNnTzfG1MaM93APPTZ8Zf0JSUYAi52OnDDNMtYXwrRuMJpbRqXwmuWBmH4+VMum3u81qzONrUK\nqKkyjvVlEI07DM5UMMaqMthYaaqvhEFQlqhVbhSgVg01/qixlkCvG/6vwSBoYCwSn2SbbDVySpso\nCzitTQWjzKiH2SSgZ2Bk3abyhVarwavP3ob/OW2D1+fHd147KloAxSBo8Ox9iyMtM8NFQmx2J6BS\nwVpeBJUqhOd/0YIOmwPB0NCIebLViK98esEY3BnR+MfgTAVDr9Ng1nUW7Dt+ZawvZVS0n+8FANRY\njRhwehGSqrICQNBpYkbAep0GNVWxjeCfvW8xBpxeXOx0oKaKI2aibGJwpoKyec1MtLTb4PbmaKHq\nDOp1eNDV68L2XSfx52OX4Q2IB2ePNxBToEWOqVhg8hfRKOCaM41LUl2p9DoNKssMY3RVo8tsMuB3\ne89i95EO2ezzcIUvIsodHDmTLLEGBbkm+hq1GpVsV6ptu0/jom0w49dQXiKgpEiLrj53zjS2mD3d\njP0fJK9uJlbhKx9+70TjGYMziZIq1hHdenGsiV1jsUEXsy0quivVhpUzJLdSjVTvoBcurz9nArOx\nSIu2093odcgnvy2bUx1T4Ssffu9EhYDBmURJFesAhlsvjjWxa5Tq0dza3oWlN07I6laqXAnMAOBw\n+ZO/CYBOp4r5Oh9+70SFgI/ClCBZ68X4etRjQe4axXT3u/GDX7eNw4aNI/Nu62Vs230aQH783okK\nBYMzJVDSenGspVNQJFm/5ULVctIWWWPO9d87UaFgcKYE4daLYqRaL442uWuk1NgHPJHkr1z/vRMV\nCgZnSpBK68WxIneNwFDVK7UKqCiQbVMjYTbpI1nZuf57JyoUDM4katPqWjQ31aCi1HCt5rIBzU01\nOdW7d/2K6yVrR5cYtHjms4vw/X++BRUcYctaUG+NBN58+L0TFYK0s7W/+c1v4tixY1CpVHjqqafQ\n0NAQObZv3z5897vfhUajwc0334yHHnooIxdLoye+9WIu7nd1OH3wSFT6Cq+dlhn1mDXVjPcKpGQn\nANRPLcPJ831J32cQNFg2N3YrVT783okKQVrB+eDBgzh37hy2bduGM2fO4KmnnsK2bdsix5977jm8\n9NJLmDBhArZs2YK1a9eitpZP3vko3HpxrMgVwwivkYptnwoB+P6ONiw/1Y31N08vmOCsVgGf/cQN\n+MqL+xEQ2dmlF9T40j2NEDRqWM3FkoF3rH/vRIUureD8/vvvo7m5GQAwY8YM9PX1weFwwGg04sKF\nCygrK8PEiRMBACtXrsT777/P4EwpUVIMI7xGGr0vN1p3vwdv7j2LQx8WRmAGhjpFVZmLcUvjZLxz\npCPh+MfmTsT0iWVjcGVElIq0gnNXVxdmz54d+dpiscBms8FoNMJms8FiscQcu3DhQtJzms3F0Gqz\nN31mtZqSv2kcy7f7f/GND0SLYRQXCXhg/dzI6w/f1YjiIgHvf3AJtl636Lku9zizfr1jTa0GplWX\n4juPrIAgaPHIpgUoKdZj//HLsPW6YC0vwpI5E3HfutnQaAor1STf/u5nUiHfO5Df95+RCmGh0MhL\nO9jt2fsP1Go1wWYbyNr5c12u33/81LXT48NbB86Jvve9Y5fwicVTYqZjP7F4CqZXG/Fv29tG65Jz\nyv23z0JDbSVMxQL6+lyR19cvn4Z7b78BZ/7aHfnZ9vRkvq54Lsv1v/vZVMj3DuTH/cs9PKQVnKuq\nqtDV1RX5urOzE1arVfTY1atXUVVVlc63oXFOaura4fZJtnQMF8OoMhcnfL4QVZQa0HTDBMm1Y4Og\n5doxUR5Ka35r+fLl2LlzJwDgww8/RFVVFYxGIwCgpqYGDocDFy9ehN/vx549e7B8+fLMXTGNG+E6\nzt39HoQwPHXderJT8jPlJj28vgA8vkDC5wsR9x8TjU9pjZwXLFiA2bNn4+6774ZKpcLTTz+N3/zm\nNzCZTFizZg2eeeYZPPbYYwCA22+/Hddff31GL5ryn1wdZ49POtQ6nD48/fIhmE0CnJ7Cq/WsUSOS\nhW0Q1AiGQggEg+wYRTTOpL3m/Pjjj8d8PWvWrMifFy1aFLO1igqT3DaodGpjA4DXPxSZegbkWyGO\nFwZBA68vALPJgCK9JqYXtdsbxO4jHVCrVOwYRTTOsGUkZZySbVBye5QNgkZyzXm8U6uG9mhbTAY0\n1lVi/Yrp6HN4sPPQefz52GXRz7S2d2HDyhmc3iYaRxicKeOU9ASW26NcWW5AV687EqD1WjU8/uz3\nSv7nuxrQM+DBG386g95BZf2QM23l/ElYu3hqzGzDG3vP4k9HxQMzEJskR0TjA4MzZZTcWvKREzas\nWzYNpmIBACJlI1vbu2AfcMNsMqDYoMWFTkfsOf1BGAQ13N7EAG0QNCjSa2Af4TS3SgW88t8n0DPg\nRaa3AWvUgFajhscn/4BhEDTYcEstivXD/yyV9K1mxyii8YfBmTJKtieww4OnXz6IpllVkSnu6DrO\nRXotvv7KIYkzq0RftZYXYVJlMQ58JJ3hrUQoNLyOLVb2MhUq1dD5wl2xnti8AIJWjadfPoheh/RD\nhNcXgMPpjQnOStbmmbFNNP4wxZMyKlmf5V6HF7sOX8R/vH0y8lq4jrPL45cMRF5fANWWooTXL3Q6\ncPRUl8gnxk64Jk8wBNh63fiXn+7H7/b9FQvrpVtcAuIjYLmfp1oFrGqcxI5RROMQgzNlVLI+y2F/\nbL2MV986iUBweJhqLNZBL9ECUqdV42qPS/RYsuniseb2BrHr8EWEADQ31Ui2uWyYYUGfwwOPbzgZ\nTu7nubJxMu5dO4vbqIjGIU5rU0Z5fAGsapyMQDCElnYb+mSmcfe0dECjHt4G9Js/nZXM0s71AKzE\nsVPdeO6Bm7B+xfXY+vYpnDhnR6/Dg3KjHiVFOrSd6cYfWy8lZLeLrc031lVyxEw0jjE4U0ZEb5/q\n7vdA0Krg9Sev2xXeBgQA+z6QzkiWky9br6Kzqv/+kzdG9oHvPHQBe1qGO0jFZ7ezxzJR4eF8GGVE\ndClNAIoCMzAcsGx2p2g2thJL50xAjbUkrc+Opvg1Zb1OgzKjHm2nxdfMW9u7Eqa4q2R6MBPR+MHg\nTCOmZLuPlEjAUolnYyvh8wfh8ozNvuRUiGVVy2a3X3twIaLCw+BMafH4Aui0OyNTs+l2hQoHLGt5\nETRq8QCdLGy/13ZFtNKYEjoNUG4U0vqsHL1ODUE7fOUGQYPQtTrY0eSysbl/mahwcc2ZUiJWmrOh\nthJmk6Co3nV8ecropCadVoWAN3E6XC+oMb/Wiv0fXRU950g6UvkCkN17nK4qc3FMMRW3N4B3jnRA\nFVcHW6tRodigE3244P5losLF4EwpESvNuaelA1OqjIqC88caqnHTDdWoqTJGKoUBQ9O7UmvObm8Q\nK+dPwoGPruZ0a0iVauiho6G2AsdOiU/zRyfA9Tk82HnwfEJFNACYUmVkNjZRAWNwJsXk1padbh9W\nNU7C+x9eFc2cVquBSRUl+PAvduw9diVhu5CxWCebdf3j3x5XFJjD1blGm8WkxxfvmoeyEgEXOx0x\n2dfRevrd+OXOkzhx3o6efo/kUrvT7Yc/EMp4KVEiyg8MzqSYfPKSB2sXT8WGW2rx2tvtQ8FnwIOy\nEgGzppZDr9fi3dZLkffHbxd6Y+9fZLdD9Q36FF3jWARmAJhfV4k/HbsUme5Xq4YqhMXT6dR47/iV\nyNdS18tmFkSFjcGZFJNr8xhOXtLrNLj/2h5eW68LCIVQZtRL1sxube/CumXT0HJyZLWxR5v62gjd\nUjq0dh4KhWKm+6WCrldhMRUmgxEVNgZnUkyuzWN08lIgGMSv3z0TGUWWGQXJpCv7gBsXOx2K1qtz\nycrGyVg1fxKgUqGsRJB8+JAaQSfDZDCiwsbgTEmFt0uVGfWKSknGJ43JZUObTQZUmYvSDmIjoVYD\nQZmBrFoFlJYMPViEr89i0mN+XSVUAL6/oy3pw0cIQGmxgH6n/MNH/EicyWBEhY3BmSSJbZsKJ3FJ\nlZJMtSBJY10lXB7/qAdmYCgw11SVoMM2KDoNfcuCydh4S22knaXL40eZUY9fv3tG8cOHqViH/sHk\nswIrGydj7aIpLM1JRAAYnEmG2Lap6CQusWSlPodHtiCI2ahH36AHZpMBc6ab4XT78b3tbZm/eIVc\n7gD+n4eW4Ve7T+PDv3RjwBWA2ShgYVTP6fB9moqFlB8++gd9srMCFpMeC+qHs9aJiAAGZ5IgF4TC\ne3XjR3iBYBA7D12QDEYVpQZ87TNNcLh82HXkIt4/fjntetqZYh9ww+UJwFgsQNBpoXIFoJaoVAbI\nZ6wDQw8f9riSm1KBefmcamxZW8+RMhEl4KM6iVJa8zm6jOe23aexp6VDMhg11lXCVCxgT2sH9rR0\njHlgBoCyEj12HjofadoRwvAMwbbdp2PeGwgGsfPgecm9yRWlBjx17wLJcqBq1VAp0opSA5qbavCZ\n22cxMBORKI6cCUBs0le4W5LctiljsYCtu9qHM7JLBLi80s0nJltLsGl17YiaZGSD3eHBn4+Jt6qM\nnyHYtvs09kTt1Y7XWFeJQDAk2cM6BODxu+dj+uQyBmUiksXgXODkkr6ktk0V6TX4j7dO4v0Ph2td\n9yZJehp0+eDxBbD17VNpN6nIFqmRfnQhELmHCrVqKKFr0+pa+AMhyYcai8nAwExEijA457H40W46\n5JK+Nq2uxcnzvQm1ny/aBnHRNpjS9+lzeLH17VPYF1UdK9eVG/WRQiBy0/yhELB20RRo1Gpo1FC0\nF5yISA6Dcx6SG+2mkvGbLOlr3bJpcLqVlc1Mxlyqx4lzPSM6h0EYCmweXwAqZH9fdEmRLhJM5ab5\nLaWx1byU7AUnIpLD4JyHkm1xUipZ0tfFTkfafZrjOZw+eP3KE8AMggZeXwDma12emhfWwFJqiFz3\nHw6cw5+Oiq8VZ4rTPTQVr9dpFFdHAwCNWo3NzXWSe8GJiJJJKzj7/X585Stfwfnz5xEIBPDEE0+g\nqakp5j2zZ8/GggULIl+/8sor0Gj4H9RIpbPFSUqypK+aKqPkcTla9VBWcnQZaaWBWaUCbmmcjA0r\nZ8Dh9IoGtipzMdYumpr14Gwf8MQ0n0h1RKzXadi4gojSklZw/u1vf4uioiK89tprOHXqFL785S9j\nx44dMe8xGo149dVXM3KRNEzJFielASHZaFDQaVA/1ZzyOnEKA+QEEyxFuPfj9QCAYr30X09LqQEG\nQZ32dixBq4I/EJKdGo9vPpFsRJyJHAAiIiDN4HzHHXfgk5/8JADAYrGgt7c3oxdF0pR0hkqF2Ghw\n/swKBEMhfPXF/ejp90TWeuVaOmaKvd+DAac3UipTPshJFwvR69TwyHSA8vmTL1hLJXDFj4gzlQNA\nRBSWVnDW6XSRP//85z+PBOpoXq8Xjz32GDo6OrB27Vp89rOfTf8qKSKVtU8lxEaDv373DN6JOn84\nKC+bUw29To22Mz2wD7gh6DIftD2+IJ5+6SD6Br2yQa7P4YFH4vuqVEBTfVVM3+R4ZpMeKhVEH3LU\nKmDl/EmKE7gylQNARBSWNDhv374d27dvj3ntkUcewYoVK/Af//Ef+PDDD/HjH/844XNPPPEE7rjj\nDqhUKmzZsgVNTU2YO3eu5Pcxm4uh1WZvKtBqNWXt3KPt4bsaUVwkYP/xy+jqdaGyvAhL5kzEfetm\nQ6MRH6kpuf8aAG6vH21nukWPn+7oww+fWA0AuNLtRCAYwH/vO4cDx6+g15G5vcvhPdPhIFdcJOCB\n9bF/d0xlRbCai9BpdyV83lpehEfubkTFzpN4++A5uDyJQfxj8ycDAN7cezbh2G1Lp+HzG+Ypula5\nn1fbmW58bkMRDMLY5l2Op7/76Sjk+y/kewfy+/6T/q+xceNGbNy4MeH17du3Y/fu3fj3f//3mJF0\n2D333BP585IlS9De3i4bnO12p9JrTpnVaoLNNpC18482jy+AZTdW4dbGSTHTvz094nuPU7n/TrsT\nNpGABwBdvS60n+3CntYOtLbb0i4mkmp7yPeOXcInFk9JmBWYO92Cd450JLx/7nQLnA4P1i+fhs1r\n6/H/vt6KE+ftsA94Iklc65ZOBQA4Xd6EBK9PfWxaxn5eZ/7aPaZJYePt736qCvn+C/negfy4f7mH\nh7Qe6S9cuIDXX38dv/zlL6HXJ65xnj17Fj/84Q/xwgsvIBAIoKWlBbfddls634qiiK1tNsyoQHPT\nFFhKDRlJQkq2pr3r8AXZEpZKTLYaEwqbyOnpF090k4rv/mAQnXYnyox6WIsE3P/JGyWTtUa65SnT\nOQBERECawXn79u3o7e3Fgw8+GHntpZdewiuvvIJFixahsbER1dXVuPPOO6FWq7F69Wo0NDRk7KIL\nldja5p7WS9jTegkVGUpC0us0aKitxJ4WkRHpDEtMyc50LJ9TjXtvq8P2PWfw3gdXIuvVekENny8o\nOqLWC5qEIOfxBXBUYkvZ3qOX8afWy7CU6rF83mSsWzpVdlvTSLY8ZToHgIgIAFShkFib+dGXzemH\nfJjeAOS34nh8AXz1xf1Jp5Kbm2oSkpCU3n94ZN5yshM9A97I9HNFqR6zpprhDwZx4KPOpOdRqYZK\nWsazmPR4/sElkXvz+AKw9boQCASx52iH5L5lg6DB9x75WORzgWAQL/3+I+xXcC2A+M8kk4ZnNBL3\nP491tna+/N3PlkK+/0K+dyA/7j/j09qUWUq24iTrIxyWaiGSaPEj8/Ao1uHy4r3jV2Q2LsWqNhfj\nck9iDsGCemvMdel1GtRYjdi6q122oIj32kNLlbkYgWAQX3/lcErT4iP5mSjBimBElGnchJkDwkFR\nrp9weG0zmehey6mQqzzm8Q1F6WRTLAZBA4OgwZUeZ+TP4f7FqxZMxqrGyfD4YjOnlbSQjF673fp2\ne0qBGUj/Z5Kq8PQ4AzMRjRSD8xhLVo4zHMzCa5vJpJuEpHRkLkavU+OmG6vg9gbg9gYQAiJ/Xjqn\nGg0zLGg73YWvvngAX31xP7buakcgGFT8fcNrtx5fAK2nulK+PiZmEVG+4bT2GEulHGd0Na/ufrfo\nZ9JNQpLLOk5m2Zxqyb2+Le22mCIl8QU65L5vdJ9kYOhn1euQ7xstholZRJRvOHIeY3LT1VK1nZ97\n4CY8/8BNWLVgMipKDVCrhqaOm5tq0m5LqHRkDgwFTVXU92xumiL5gCFVPSw8KyD3fVfOn4R7P14f\nWXcv0mtRbhQUXWNYkXhq+nIAABFXSURBVF6D9Sump/QZIqKxxpHzGEtnK45ep8HEihLc+/F6eFZl\nrtlCfJ1tQacRDa4r50/C2sVTI9/T4wukPOruiZoVSNbtKTphLtWRs8cbgMPplW2iQUSUa/g/Vg5I\ntRVhtEy2JYzOOu7pd+Otw+dx4MOrkc5PBkGDJbOr0Nw0JeFhYNZUs2wt63gqADsPnsfmNXVJs53j\ns8hTUVlexPVmIso7DM45YKRbcTLdqlCv02BPawfebY3d3uT2BrD/w068e63Ax/yZlQgBOHaqC939\nHhgENQAVvL6A5Kg7LBgC9rRegkajjuxBFnvQkEuY0+vUKDFo0evwSn6/JXMmcr2ZiPIOg3MOSXUU\nnK1WhXIBMRwAu/s9CXWtwyPsJbMnoP28XVG3qmR7kOUS5nz+IL5413wIWjWMxQLe2Hs2YfbhvnWz\nJWuOExHlKgbnPJatVoUj2VYFACfP9cKucG04PiM9nnztaj2s5UWRwC42+yDVpYuIKJfxf648pXR/\ntNznO+1O0fcpLXgixZ5CwY9ke5D1Og2KDYldzwCg2KBLGHGzEAgRjQccOecpudFtd78bPf1uTKwo\nSTgWPxVebtRjfl0lNjfPjEyFy2WQK5FKS8hke5A9vgAGXeKj8EGXL7Idi4hoPOHIOU8lG93uOnxB\n9PX4UqF2hwd7Wjrw9VcOR6p2AUMZ5M1NNZF91AZBeQBUEpjVKmDVgslJM9L7HB7YB8SDc6/DMypl\nOYmIRhtHzjkqWQa2XGtHAGg705MwqpSbCr/Q6cDWt9tx79pZAGIzyMOdo/7Udhltp7sjCVfzZ1Zc\ny9Yefq2htgLHTtnQIxFQw0IhYO2iKUkT19gvmYgKEYNzjkklA7t5YY1kcBZLtEqW6PXe8SvYcEtt\npGBHIBjEr989E3MtDTMq0Nw0BZZSQyTwb7wl9kFCo1YlnRK3lIoH1vBDSZFeC5fHjzKjnv2Siajg\nMDjnmFQysC2lBlSkMKosM+pRbtRLJmx5fUG89nY77v/kjZLXEr83GUjcApZODXC5XtLzZ1Zi9cLJ\nMSN0pUVaiIjyEYNzDpGbdj58ohPrlk2DqXi4tnSqpT/1Og3m10lPhQPAifP2SAa3XDa43N7k+Epj\nuw5fQNuZHtnAKtVLOryfurmpBs89cBP7JRNRQWBwziFy0869Di+eefkQFs6KneJOtfTn5uaZOPFX\nOy73OEWP2weGk6yUdsuSEqkBvnaW7Bq6kp7O4QeCTJUqJSLKZQzOOSRZ20a7I3GKO9XSnxq1Gl/5\nu4V47P++B48vmHA8PB0eCIagF9SRql9i70mlbKhc9TMlRU+UPhAQEY0HDM45ROn+4tb2LqxbNi2S\nMKXXaVIq/Vms12HFvEmS0+EAsPXtdtHADADzZlYkJIqNpGyokl7SzMwmokLC4JxjwtPRh090SrZH\n7O534+mXD6LP4U07MIpNh8+bWYFQKISvvrhfMlAaBA2CwRB2tw6vW4+0bKiShxJmZhNRIWFwzjHh\naep1y6bhmZcPSWZWhwN3KoExfho6fjr81++eSTpq93gDOHaqW/RYskQxOeGHhZaTNvQMeGKytcMP\nH0REhYLBOUeZigUsnKW8hGZ8YIwOxIFAEFt3tYtOQ4enw5UkZQFAmVFAr8QDw0jWhePXzqP3OXPE\nTESFhsE5h8VPPZeVSO9RDgfGijJDQhGTMqMeZy/1R94rNtpW2omqcWYl2s50Z61iV/TaefS2MSKi\nQsLgnMPERpNff+WQbGAUKxwitX4cPdpOlpRlMemxoP7a2rbmNCt2ERFlEYNzHogeTcoVHQGkC4eI\niZ6GlkvKWj6nGlvW1kcCb6p7q4mIKDUMznkkEAzCHwhCr1XD4x/a5mQQNFg2txqbVteiu8+taGo6\nLH4aWi7oRmeCp7q3moiIUsPgnMOik7q0GhW+/sphXOh0xLzH7Q3A6fLjctegov3C0eKnoVMNuqns\nrSYiIuXSCs6/+c1v8P3vfx9Tp04FACxbtgyf//znY97z5ptv4uc//znUajXuuusubNy4ceRXWyDE\nOlMZ9Fp02AZF37//o6vY/9FVGAQ1KsuKACQG5ylVRjjdfkXT0Ay6RERjK+2R8+23344nn3xS9JjT\n6cQPf/hD7NixAzqdDnfeeSfWrFmD8vLytC+0kIgldYkF3HhubxAXbYMJgXj5vElYt3Qq/IEQp6GJ\niPJAVqa1jx07hrlz58JkMgEAFixYgJaWFqxevTob3y4vKK1D7fEF0HKyc0Tfa9Dlw9OfXRTZJ1wz\nqRw22wA0anBETESUB9IOzgcPHsT9998Pv9+PJ598EjfeeGPkWFdXFywWS+Rri8UCm00+i9hsLoZW\nm73RnNVqysp53V4/7P0emEv1MAiJP85AIIiXf/ch9h+/DFuvC9byIiyZMxH3rZsNjUad8N4f/Ooo\negbEy3YqZR/woKjEgOnXlURey9b954NCvneA91/I91/I9w7k9/0nDc7bt2/H9u3bY177m7/5Gzzy\nyCO45ZZb0NraiieffBK/+93vJM8RCoWSXojdLt7CMBOsVhNstoGMnlNsXVisxvXWXe0xU9Sddhfe\n3HsWTpc3odzm1l3t2K2wIpgcs0mPgNcXueds3H++KOR7B3j/hXz/hXzvQH7cv9zDQ9LgvHHjRtlk\nrsbGRvT09CAQCECjGRr5VlVVoaurK/Kezs5OzJ8/P5Vrznli68LxVbfkSmKKldtMtkc5vJbc0++G\nTqeGV6TlIwAsqLdyTZmIKI+l3t8PwIsvvojf//73AID29nZYLJZIYAaAefPm4YMPPkB/fz8GBwfR\n0tKCpqamzFxxDkgWdD2+AAD5kpjhAiBhycpnLptTja99pgnPPXATvvW5Jfjuwx/DrQsnwyAM/9wN\nggarF05mMRAiojyX1przunXr8KUvfQmvv/46/H4/nn/+eQDAT37yEyxatAiNjY147LHHcP/990Ol\nUuGhhx6KJIeNB0qCbpW5WHbfcXwBELn3VpTqce/aemjU6pikrv+1ph533lILW68LCIVgvVbpi4iI\n8ltawbm6uhqvvvpqwusPPvhg5M+33XYbbrvttvSvLIcpDbpyJTHjC4DIv1d6mlqv06DGakz3VoiI\nKAexQlgaUgm6qdShZs1qIiICGJzTpjSQplISkzWriYgIYHBOWzbrULN8JhFRYUsrW5uGhQNpvoxw\nPb4AOu3OSEY5ERHlHo6cC4TSoilERDT2GJwLhJKiKURElBvG7ZCJ07fD3F6/oqIpRESUG8bdyFls\n+nb5vMlYt3Rqzk/fKu1clSp7v7KiKURElBvGXXAWm76VajSRK7K9HmwuVV6pjIiIxl5uDyVTpLTm\nda4JP1B093sQwvB68LbdpzNyfoOgRWOdVfRYfNEUIiIae+MqOKfSaCJXjNYDxabVtWhuqkFFqQFq\nFVBRakBzUw2rjxER5aBxNa2dSqOJXKG0icZIsfoYEVH+GFcj53DNazG5On0bfqAQk40HinwrmkJE\nVIjGVXAGxKdv71gxPWenb/PxgYKIiLJrXE1rA+LTtzWTymGzDYz1pUliNyoiIoo27oJzWD41j8jm\nerDHF8DlrkEEfAGOwomI8sS4Dc75KJMPFDF7pwc8sJhYS5uIKF8wOI9TrKVNRJS/OIQah/K1GAsR\nEQ1hcB6H8rEYCxERDWNwHodGe+80ERFlFoPzOMS900RE+Y0JYRmWrbaPqeLeaSKi/MXgnCHZbvuY\nqui90xpBh4DXxxEzEVGe4LR2hmS77WO69DoNJlaWMDATEeURBucM4NYlIiLKJAbnDODWJSIiyqS0\n1px/9KMfYd++fQCAYDCIrq4u7Ny5M3L84sWLWLduHebMmQMAMJvN+MEPfpCBy81N+dhHmoiIclda\nwfnzn/88Pv/5zwMA/vM//xPd3d0J77n++uvx6quvjuzq8kR461J0ucwwbl0iIqJUjShb2+/347XX\nXsMvfvGLTF1P3uLWJSIiypQRBee33noLH/vYx2AwGBKOdXV14Qtf+AI6OzuxefNm3HHHHSP5Vjkv\nm20fiYiosKhCoVBI7g3bt2/H9u3bY1575JFHsGLFCtx///149tlnUVNTE3Pc4XBg586duOOOOzAw\nMICNGzfitddeQ1VVleT38fsD0GoZzIiIiJIGZylOpxMbN27Ef/3XfyV976OPPop77rkHS5YskXyP\nzTaQzmUoYrWasnr+XFfI91/I9w7w/gv5/gv53oH8uH+r1SR5LO2tVCdOnMD06dNFj+3fvx/f+ta3\nAAwF8RMnTuD6669P91sREREVlLSDs81mg8ViiXnt+eefx4ULF9DU1IS+vj5s2rQJn/70p/Hggw9i\nwoQJI75YIiKiQpD2tHamcVo7ewr5/gv53gHefyHffyHfO5Af95+VaW0iIiLKDgZnIiKiHMPgTERE\nlGMYnImIiHJMziSEERER0RCOnImIiHIMgzMREVGOYXAmIiLKMQzOREREOYbBmYiIKMcwOBMREeWY\nggjO3d3d+Pu//3vce++9uPvuu3Hs2LGxvqRR4/f78eSTT+Kee+7BXXfdhcOHD4/1JY26gwcPYunS\npdizZ89YX8qo+uY3v4lNmzbh7rvvRltb21hfzqhrb29Hc3MzfvnLX471pYy6b3/729i0aRM2bNiA\nt956a6wvZ1S5XC48+uij2LJlCzZu3Ji3/+61Y30Bo+HNN9/E3/7t32LdunU4ePAgvv/97+Pll18e\n68saFb/97W9RVFSE1157DadOncKXv/xl7NixY6wva9ScP38eP/vZz7BgwYKxvpRRdfDgQZw7dw7b\ntm3DmTNn8NRTT2Hbtm1jfVmjxul04hvf+AaWLl061pcy6vbv349Tp05h27ZtsNvt+NSnPoWPf/zj\nY31Zo2bPnj3/f3v3D5JaFIAB/BNvRtHfK9ewLVqKIlqaoqJoimgTWguChhqL4g7NRrQooZiDQ2Bo\nBEFDEVE0BOGoREtLiFEXScqSQHhDcHnCe5EP3j3q+X7TuWf6DlzOxz2IB/39/VhYWEA6ncb8/DzG\nx8dFxyqbFOU8NzdnjjOZjFTXV87MzGB6ehoAoKoqXl5eBCeylqZp8Pv90HVddBRLXV9fY3JyEgDQ\n3d2NXC6Ht7c3NDU1CU5mDYfDgVAohFAoJDqK5YaGhjAwMAAAaGlpwcfHB4rFIux2u+Bk1piamjLH\n1bzfS1HOwNf904uLi8jn84hEIqLjWKaurs4cRyIRs6hl0dDQIDqCEIZhoK+vz3xWVRXPz8/SlLOi\nKFAUaba3Ena7HY2NjQCAeDyO0dFRaYr5d7Ozs3h8fEQgEBAd5Z/U3Nsbi8UQi8VK5paXlzEyMoKD\ngwNcXl5ifX29Jo+1v1v73t4eUqlU1b6oP/Hd+mXHf+mVz9nZGeLxeE3udT8RjUZxe3uLlZUVHB0d\nwWaziY5UlporZ4/HA4/HUzJ3c3ODXC6H1tZWjI2NYXV1VVC6/+tPawe+Suv8/Bw7OzslX9K15m/r\nl5HL5YJhGObz09MTNE0TmIisdHV1hUAggN3dXTQ3N4uOY6lkMgmn0wm3243e3l4Ui0Vks1k4nU7R\n0coixa+1T09PcXh4CAC4u7uD2+0WnMg6Dw8PiEaj8Pv9qK+vFx2HLDI8PIyTkxMAQCqVgsvlkuZI\nW3avr6/Y3NxEMBhEW1ub6DiWSyQS5mmBYRh4f39He3u74FTlk+JWqmw2i7W1NeTzeXx+fkLXdQwO\nDoqOZYnt7W0cHx+js7PTnAuHw3A4HAJTWefi4gLhcBj39/dQVRWapklzzLe1tYVEIgGbzYaNjQ30\n9PSIjmSZZDIJr9eLdDoNRVHQ0dEBn88nRVnt7+/D5/Ohq6vLnPN6vSV7QC0rFArQdR2ZTAaFQgFL\nS0uYmJgQHatsUpQzERFRNZHiWJuIiKiasJyJiIgqDMuZiIiowrCciYiIKgzLmYiIqMKwnImIiCoM\ny5mIiKjCsJyJiIgqzC8iivHPF8qqogAAAABJRU5ErkJggg==\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f7a18dfb8d0\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], - "source": [ - "# Plot the Data (Optional)\n", + "assert f(pi/2).numpy() == 1.0\n", "\n", - "import matplotlib.pyplot as plt\n", "\n", - "plt.scatter(inputs, labels)\n", - "plt.show()" + "# grad_f will return a list of derivatives of f\n", + "# with respect to its arguments. Since f() has a single argument,\n", + "# grad_f will return a list with a single element.\n", + "grad_f = tfe.gradients_function(f)\n", + "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JaFHyAG9nDET" + "id": "v9fPs8RyopCf" }, "source": [ - "## Step 2: Define our TensorFlow variables\n", + "### Higher-order gradients\n", "\n", - "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias." + "The same API can be used to differentiate as many times as you like:\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, - "base_uri": "https://localhost:8080/", - "height": 34 + "height": 276 }, "colab_type": "code", "executionInfo": { - "elapsed": 332, + "elapsed": 730, "status": "ok", - "timestamp": 1525154229931, + "timestamp": 1527005655565, "user": { "displayName": "", "photoUrl": "", @@ -190,54 +120,61 @@ }, "user_tz": 420 }, - "id": "z9r-ZeyrXu3A", - "outputId": "e19a698e-5892-4fcd-80d3-1394605ee72c" + "id": "3D0ZvnGYo0rW", + "outputId": "e23f8cc6-6813-4944-f20f-825b8a03c2ff" }, "outputs": [ { "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAEDCAYAAAAhsS8XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXd0HNX5sJ/ZXrTq3ZLV3IvcDdgGGwOm2WCbHhJa6C2B\nUBISQioBfoQPkjhACA4QCIQSDITQbGMbsHHvVbZ6s7q0vc18f4xmJVltJa0q+5zDOXhn9s7dqzvv\nfe/briBJkkSYMGHChBkxqAa7A2HChAkTJrSEBXuYMGHCjDDCgj1MmDBhRhhhwR4mTJgwI4ywYA8T\nJkyYEUZYsIcJEybMCCNkgl0URVasWMHtt98eqibDhAkTJkwvCJlgf+2118jJyQlVc2HChAkTppeE\nRLBXVlayceNGrrjiilA0FyZMmDBh+kBIBPvjjz/OQw89hCAIoWguTJgwYcL0gT4L9g0bNhAfH8/E\niRMJVycIEyZMmMFH6GutmGeeeYYPP/wQtVqN2+3Gbrdz3nnn8dRTT3X6HUmSwtp9CKittvH8UxsQ\nxZY/4aXXTGfa7PRB7NXAU1dj5y9PrIfmYUgeFcnya2aQmBI5uB0bYE5WNPHS/9uE6JcHYukVucw8\nPWOQezXw7NhcyCfvH0Bqfi+uumkO4ycnD3KvBpY+C/bWbNu2jdWrV/PCCy90e291tTVUj+03EhIs\nQ7qfWzfls2tzMTNPH01UrJEv/3eU5LRIVnx/5mB3rUP6azw3fnaMQ7vLOX1RNrVVNvIOVZGeFcPS\nq6YNmT6GmlP7KYoi/3ltF9WVNhacO4btXxfi9fi5+Mpc0jJjhkw/+5t9O0r5Zu1xDEYtpy/KZuOn\nR4mOM3HlTbNRqTo3UAynv3swhOPYhymSJJF3sAqtTs35l05mQm4K6VkxVJY2UVdtH+zuDRgOu4ej\n+yqIjDYwbW4a514yiYTkCMqKGnC7vIPdvQFjz9YSqittjJuSxNTZaVywcgoAX3xwCL9PHOTeDRyH\ndpej0ai47PqZTJyWwoTcFOprHBzdf3KwuzaghFSwz507NyhtPUzfOVnehLXRRdbYeLQ6DQATp6UC\ncGhv+WB2bUA5sLMMv19i2pz0gEaWNS4BUZQoOlE3yL0bGDxuHzu+LsRk1jH/nDEApI6OZtL0VFxO\nLyfLmwa5hwNDU4OT+loHozJiiIw2AjB7QSYajYrtXxXg9foHuYcDR1hjH6bkHawCYOzkxMBnmWPj\nMJq1HDtwEt93YBJ7PT4O7CrDYNQwPrfFhpo1Lh6AgmPVg9W1AaWyrBG/X2JCbjIGozbweXqWbIIp\nLawfrK4NKMX58kI+Oic28FmERc/UOWnYbR7yDn53tPawYB+GiKLI8SNVGEzaNvZTtVrFhKkpuF0+\n8o+OfKGWd7gKt8vHlJmj0GrVgc9j4kxExRopzq/7Tixw5cWNAKSkR7f5PHV0NIIApUXfEcHevEMb\nnR3b5vPxU5IAqChpHPA+DRZhwT4MKS2sx+XwMmZCYjuH0MRpKQAc3lsxGF0bUJQXNWdiYpvPBUEg\ne1w8Pq9IyXdAW60oaUAQ5Gig1uj0GhJTIqkqb8Lj9g1S7wYGn89PWXE90XGmgBlGITrWhN6gobIs\nLNjDDGE6MsMoRMUYSUiOoLKsacQ7zaoqrGh1amLiTO2uZY1LAKDgWM1Ad2tA8Xr9VFVYSUi2oNNr\n2l1Py4xBkqC8pGEQejdwVJQ04vOKZJyirYO80CeNiqSpwYXD5h6E3g08YcE+zJAkiZKCOswWHUmp\nHcdpJyRbEEWJupqRGx3jdvloqHWQmGLpMCciMcWCOUJH0fEaRHHkLnBV5U2IokRKelSH10dlyOaZ\nkW5nD5hhcuI6vJ48Sh6fyrLvhiM5LNiHGQ67B6fDS2JyZKdJXgnJcqxrdeXQj8vtLcpv6ywJSRAE\nMsfF43L6RvTLXF4sa+Kn2tcVkkdFodGoKCvqu8b+zjtv8f3vX8Fvf/ton9sKNUX5tWi0KlLSOl7g\nFDPVSJ4LrWm/dwszpKk5aQMgLimi03u+C4JdCeFLTOk8YSMlLYqDu8qpOWkjtRPBN9wpb/YzpHai\nsas1KlLSoygpqMdhc2OK0Pf6WWvWvMsf//hnkpNTet1Gf9BY76Sxzknm2DjUmo51VXlnBye/I3b2\nsMY+zKitkgV7fGLngj02wYxKLVBdaRuobg04VRWyYO/MHAUQlyCPkTJmIw2/X+RkeRNxCWb0Bm2n\n943KaA577IPW/vTTf6C8vIyHH76ft99+s9ft9AeKUzQto/MMW61OQ1xiBFWV1hHve4Kwxj7sUDT2\n+C40drVaRVyCmdpqG36/iFo9stZvSZKoKrditugwWzrXQKNijahUwojLxH17/XF25VXj8fhx+Hzo\nGh1s/+vmTu8X/SJ2RA59egTDxhMd3jNnQiJXLh7TaRsPPPAztm79lj//+UUiI4dWDZ76Zl9SXBfK\nDshmqZqTNqpPWgM295HKyHrjvwPUVNnQGzRERHa9pU5ItiD6pREn1ADsVjcOu6fbIl9qtYqYeBN1\nNfYRWXlU0Ty7W7hVzddFf181VYlApbUhhDLHYxPMXd6XnCbPl5PfATt7WGMfRng9PhrrnM2JJ11X\nx5Tt7BVUn7QGbO4jhZPliuO0+98VlxBBbZWdpgYnUTHtwyKHI1cuHsNdV83g1b9upuhELdf/cG63\ntvN/v7ydpgYnN99xxoirrFpX48Bk1rXJuu2IlsiYRqYxsiughjX2YURts2bSlRlGocWBOvLsy4p9\nPZiyvLGJshZXWzXydi71tXYMJm1QDtHYeBM+r4itaWTFcXs9PqyNLmLiu1+0IyL1mCN0VJY2jcgd\nXGvCgn0YEbCvd2NLBIiNN6NSCdSMwMiYqoqeaeww8hyoPq9fFmixwe1CouPkBa6+ti8L3NDT9Otq\nHED3ZhiQQ2ATUyNx2D3YbZ7+7tqgEhbsw4hgHKcKao2K2AQztVWyA3WkIIoS1ZVWYuJNHWZankpc\n8wtfO8J8DbLfAKI7yLrtiNhmjba+WRD2hnfe+YDIyKHldAzY1+O7F+xAIEu5sa734zAcCAv2YURt\nlQ2VWgj6ZU5ItuD3S4GogZFAU4MTr8dPQlJwfgNThA6DUTPinMi11fIi31E5hY5Q5kx97cgSaMrc\nDkZjB7luDEBDnbPf+jQUCAv2YYIoitRW24mNNwcdvjgS7eyN9fILGR1r7OZOGUEQiE2IoLFeXhBG\nCjXNpqXoYE0xMSYEoa+mmKGHUjYjJi44wR7VPG/CGnuYIUFDnRO/TwzKDKOg3DuS7MuNzZpWVJAC\nDVrMMSOpdk5AsAepsas1KiJjjNTXOEaU47Cuxo7ZokdvCC7AL6yxhxlS9MRxqqBotU0NI2cSN9bL\nmlZUTHAaO7Qkrijmi5GAYpazRBmC/k5snBm3y4fTMTKODHS7vNitnqDNMAAGoxaDUUNDWGMPMxRQ\nbMTdZde1Rm/QojdoaGxw9Ve3BhzFFNMTwa68+HUjJORRkiRqquxExciZtcESHXCgjoxxCETEBBHq\n2JqoWBNNDc4RFVRwKn0W7B6PhyuuuILly5ezbNky/vKXv4SiX2FOQdG6eyLQlPubGpyI4sjYfjfW\nOzGatEFFxCgoERMjJTLGYfPgcfuCdpwqxI4wB2rAcRpkRIxCdIwRSQJr48hReE6lz4Jdp9Px2muv\nsWbNGtasWcOmTZvYt29fKPoWphVNDU7UGhWmCF2PvhcZbUT0S9itwz8xxe8XsTa6Ag6wYNHq1ETF\nGKkbIaYYRTAHa19XiGkWgL0NeWxdtvebb77ijTdeDfq7lZUVfPHFp0Hd+/jjv2bjxvXd3te6lMCa\nNe/x2Wf/C6r9qICdXR6HTz75L9XVLUdJPvnk7ykqKgyqraFKSEoKGI3yi+bxePD5RvYRXINFU4OL\nyChDj9PBFQ2/sd7ZI3vsUMTa6EKS6FVpgMgYIyX5TjxuX4+0/aGIIpCCTU5SUByHvY2MObVs7/z5\nZ7a7x+/3o1ar231eXl7GF198xnnnXdCrZ3eE4gyPjDawfPllQX9PGQfFEf+//33EzJlTSUrKAODh\nh38esj4OFiGZ4aIosnLlSoqLi7n22mvJzc0NRbNhmnG7vLhdvnZnWgZDZLQszGVTTudlTYcDgYiY\nHpqjACKjlHFw9SiyaCjS0EuNXatTY4nU98oU07ps78UXX4LFYuHIkUPcd99DPP74r7FYIsnLO8r4\n8ROZP/9MnnvuaQRBQKvV8OyzL/Dii6soKirkppuu5YILlnLllde0af+ZZ55k9+6dpKSktonaOXr0\nCH/+8zO4XC6ioqL5+c8fIzY2jnvuuQ3RFUt1XSGxH5Zht9sxmUycccYCfve7x3jpJXk3UVlZwcMP\n38+rr77JK6/8nW+++QqHw4lWSmTS9HvYsGEdR44c5sEHH0Sj0fL886t54IF7ufvu+zh8+ADl5eXc\neee9gKzZHz16hB//+AE+//wT3nnnLfx+H5MmTeEnP/npkKrBExLBrlKpWLNmDTabjTvvvJPjx48z\nZkznJUDD9IymZufnqYf0BoMiBEdCZExDLyJiFJQFztroHPaC/Vv315ROK+RP+ZsRCnomTJxjPfh8\nInnfbGgjiGYkTmXlmKWdfu/Usr2ffPLfNt8vLS3mT396AYCHH76Pn/zkp0yZkktEhIamJg+33343\nb731Ok8++f/atb1x45eUlpbwz3++TU1NDd///hUsXXopPp+PZ599iieeeIaoqGjWrfuCF19cxc9+\n9kskScLpsHPdlT9l6VXTWL36bwBkZGTi9/uoqCgnJSWVdes+55xzzgPgssuu4oYbbsbn9XPj9+9k\n566t3P/Idbz33tv88pe/ICGhbWGwRYvO5fbbbwwI9nXrPuf6639IUVEh69Z9zgsvrEatVvPHPz7J\n559/wvnnX9Sjv0V/EtI9aUREBHPnzuWrr77qVrAnJAyPioNDoZ/VzdUMU9OiO+1PZ58b9HLFO5fD\nNyR+S1/64HHKCUaZ2fE9bidttLxb8fukbr87FMapK9xuH6oIAU0v6uyrNSp8PhEkUKtbBLPJqOv2\nd6tUEBdnJjragsViwNj8HYNBy8KFSwPfP/30uTz//HMsW7aMJUuWkJSURHS0CZ1O0+Ezjh07wIoV\nl5KQYCEhwcK8eWcQGWnEZquhoCCfBx+8F0mSEEWRxMREEhIsCAhkpE4nMTmShAQLZrMes9lAQoKF\npUsvZuvWTdxyyy1s2rSeZ599loQEC7t2bebll1/G6XRSVV9FWXkaCQkWtFo1ktQyL7RaNTExJsaO\nTSczM4OKigJGjx5NeXkpixcv4I033uD48WPccceNSJKE2+0mLS15SM2bPgv2uro6tFotFosFl8vF\nli1buPXWW7v9XnX10C9OlZBgGRL9LC2WDyJWaYQO+9NVPyVJQqNVUV1pHfTf0tfxPFkhn5QjIva4\nHalZhlWUNnb53aHyN+8Mr9dPbN5YZo45gwvPn9Lj7x/aU87GT49x9sUTmDA1uc217n63KErU1trw\netVYrS6cTg/V1VZcLi8+X8vcXLHiGqZNm8uWLV9z5ZVX8swzq2hocODx+Dp8htPpwWZzB6653V6a\nmpzU1dnIysrm+edXt+uny+VFY9Gh0amorrZit7uRJDXV1VZOO+0sHn30p8yaNQ+/X8JojKGsrJZf\n/erXrF79OvHxCTx8/2+wNjgpL6vH6/W3+f1er5/6egfV1VYWLDibd99dQ0ZGJvPnL6S62orV6mTJ\nkou47ba7ejR+oSDYxaPPUTHV1dVcd911XHrppVxxxRUsWLCAhQsX9rXZMK1QzCi9McUIgkBktJHG\nBuewzzhsqHNiMut65fxsbYoZziip8PGJPQvxU1Ac6LZ+DPUrKyslOzuHa6+9nilTplBcXIjJZMZu\n79hpO23aTNau/RxRFKmpqWHXrp0AjB6dSX19AwcO7AfA5/NRUJAPtBwy0tE7MWpUGmq1ilde+TuL\nF8tmGI/HgyBAZGQUDoeD44W7geY5ZTJhs3UcMbVw4WK++mpDG5POrFlz2bBhHfX1ssLV1NREZWVl\nr8aqv+izxj5+/Hjef//9UPQlTCcoSTmW6N5FtcihfnacDi8mc8/CJYcKfr+IrcnV6yPN9AY59r1p\nmMcuK6nwSjninqIIdmtTb8YhOHv+O++8ya5dO1Cr1YwfP47TT58PgFqt4cYbv8eFFy5r4zxduPBs\ndu3azvXXX016egYzZswCQKPR8LvfPcmzz/4fNpsNUfRz5ZXXkJWVjd8vtfk9p7J48RKef/5P3HLL\nnYBsJl62bAXXXXcVKSmpZGeNw14vv1sXXbSMxx57DK1Wx/PPr27jO7BYLGRmZlNcXMiECZMAyMzM\n4pZb7uT+++9CFCW0Wi333/8QycnJHfZlMBCkQVLjhvJ2V2GobMtff/5b/H6R6++e1+H17vq5ef0J\n9m4rYcX3Z5CcNnhlV/synvW1dt56aTsTpiZz9sUTetXGO//YQUOtg5t/cmanEQxD5W/eGbu/Lebb\nDflcddOcwCEiPcHn8/PS018xKiOaS66Z3g89bEt/jeen/zlAwbEarr9nXq+UleL8Wj5+ez9zFmQy\ne0HmkP+7KwyYKSZM/6JoqpG91NahVSz7MI6MCZQS6GFyUmssUQZ8PhGnffgesqBkS0b38pg/jUaN\nyawb9lmX1kYXGo0Ko6nr4/A6I1AMrH5kZOGeSliwD3FsTW4kCSKjei/QomLkRUERjsORvsSwKyj2\n2OFsjrE1m1D6Mg4RUfrmeTV8fS7WRheWXiTsKUREGlCpBJrqh+9c6IqwYB/iBBynoRBoI0Fj78OB\n1C3JWsP3ZbY2udHp1d0e3NwVkVEGRFEatsfDuV0+3C5fnzKpVSoBc4QOm3X4zoWuCAv2IU5LclLv\nJ7GinQxrjT0g2Hs/DgHH4TBd4CRJwtroIiKyb6UhlO/3Z2RMf6LsWvpaIiMi0oDd6hmRVR7Dgn2I\n05dQRwWVSsASbRjW205rkwuDSYtW1/tAroDGPkwFmsftw+vxY4nU96kdRSAO13FQ+t3bKDGFiCh5\nHEdCgbxTCQv2IU6LYO/bJI6KNuJyyjVnhhuSJGFvchNhCZFAG6amGGujLIAi+qipBmLZexXyOPhY\nlV1sCDR2kP1YI42wYB/iNDXI3v++xp8PZzu72+XD5xOJ6KOmqtGoMUcM34gQJfbc0kdTjPL94TYO\nu3fv5KGH7gv0uzNTzD333MbRo0e6bU/Z+diaXPzpT39i587tverX22+/idvdsjg89NCPsdsHt0R0\nWLAPYSRJoqnBiSW6995/BWXbaRuG207lRY6w9L3ssCXagK3JNSztqoqG3dcFztI8F4abYAcQBLoV\n7MGiaOyNDU7uvfdeZs2a06t23nnnTdzulrF86qlnMZsHt9Dc8C5MPcJxu3x43H5S0ntvX1dQzBjD\n0Z6oLEZ9FWggh41WljZht7r75LcYDBRTTF8FmlanwWDU9Eiwu1wufvnLn1JdXYUoilx//c0sXnxu\np2V1y8pK+b//exybrQlJEvjtb58gNXUUq1Y9x9atmxEEFddddxPnnHMeu3fvZPXqvxEVFU1BwQkm\nTJjIo4/+FoBvv93Mn//8DNHRMYwdOx6ApkYnGq0qEBnkdrt5/PFfU1RUSEZGBh5PS7TP9u3f8vLL\nf8Pr9TJqVBqPPPIYBoOBK664hLMXXcAXm7/Er1vGV9veZNas09HrDfzvfx/xm9/8AZB3Cf/+9xs8\n8cQzPP30Exw9egi3282iRedw00238u67b1FTU80999xOdHQ0zz33PFdccQkvv/xP3njjNZKTU1ix\n4nIAVq/+G2azmauuupZ//euffPnlF3i9Ps46axE33dR9fa2eEBbsQxjlxeurLRFaBPtwtCfam0In\n2C2tQh6Hm2BXNHb/hv+yY9WePu065tg8iH6J/IffBsAyew4JV1zd6f1bt24mPj6Bp556FgCHw95l\nWd1f//oXXHfdjaxYsZTy8jpEUWTjxvWcOJHHa6/9m/r6Om6++TpmzJgJQF7eMV5//R3i4uK4444f\nsn//XsaPn8hTT/2eP//5RUaNSuOXv/wZ0D6Gfc2adzEajbzyyr84ceI4N910LQCNjQ28+upqnnvu\nr+j1Bt5441Xeeut1brjhZvk3R5pZMu8uRqfHcrSkVB6XOafx9NN/wO12odcbWLfuCxYvXgLAbbfd\nhcViQRRFfvSjO8jPP87ll1/Nv//9ZqCcsYzcr3PPXcJzz/0xINjXr1/LM8/8me3bv6W0tJiXXnoN\nSZJ4+OH72bt3D9OmhS4TOCzYhzC2EAo087DW2BUTRN8XuMCBG43D7+ARa5MLlUpAq1XTVxe4oBKQ\n/CKSJJs3uiM7ewyrVj3HCy/8hTPOWMC0adPJzz9Bfv4J7rvvruayuhLx8Qk4HA5qaqpZsEAuBqjV\nypr1vn17OPfc8wGIiYllxoxZHD58CJPJxKRJk4mPjwdgzJhxVFRUYDAYSU0dxahRaQAsWXIhH3zw\nH3kXm9ayKO/Zs5srmhelnJwxjBkzDoCDBw9QWJjPHXf8EEmS8Pl8TJkyLfC9JUvO57//ymuj7KjV\nak477Qy+/vorFi1azJYtX3PXXT8CYN26z/jwwzX4/X7q6mopKCggO3sMIDX/pyD//9ix42loaKC2\ntob6+noiIyNJTEzinXfeYvv2bdx007VyXXmni9LS4rBg/66gCGFzH6NBWrcxHCMhrMoCF4JxaHEi\nD79xsDW6MVv0JF55NQl33dKn2ibfrDvOvu2lrLxuJkmp3Z/MlZ4+mpdffp0tW77hxRf/wty5p3PW\nWYvIzs5pV1bX4ei4iuOpma6t/60IfwC1WoXf3/HS5fPKu5RTzVGtfVBKu5IkMWfO6Tz22O86bMto\nNBIRaWj3TixefB7/+c/bREZamDhxMkajkYqKct566w1efvmfmM0RPP74r/F4uleSzj77HL78ci21\ntbWcc86SQL9+8IMbuOSSFd1+v7eEnadDmIBtOQQCTa2WD8Iejs5TW5MbQQCzpe+VKZXdj32YmaT8\nPhGH3ROyc2t7GvJYU1ODXq9nyZILuOaa73Ps2NFOy+qaTGYSE5P46qsNAHi9XtxuF9OmzWTdui8Q\nRZH6+nr27dvDpEmTO31mRkYmlZUVlJeXAbB27Wf4mmuntx6H6dNn8PnnnwCQn3+cEyfyAJg8eSr7\n9++lrEw2s7jdLkpKits8IyJSj8ftlw8faWbGjFkcO3aUDz9cEyjVa7fbMRqNmExm6upq+fbbzYH7\nuypJvHjxeaxb9zkbN67n7LPPAeC0007n448/xOl0No9tdaAEcKgIa+xDmFBq7CAvEDVVNiRJGlLn\nM3aHvcmFKUKHStV3PSSwcxlmC5xijuprcpKCEvIYbJJSfv5xVq16DpVKQKPR8sADP+uyrO4vfvFr\n/u//HueVV15CENT89rdPsHDh2Rw8uI8bbrgGQVBx5533EhMTS2FhQZtnKXNTp9Px4IOP8OCDPyI6\nOobc3OmcrKiT+99KsC9ffjmPP/5rbrjhe4wdO45Jk+QDSKKjo3nkkcf41a8ewePxIggCt9xyB+np\no1Hs4Ip5T1kwQD7qc968BXzyycf84he/BmDMmLGMHTueH/zgKlJTR5Gb22LSueSS5TzwwL3Exyfw\n3HPP07q8cVZWNg6Hg4SEJGJj4wCYM+d0iooKuf32GwEwmUw8+uhviYkJnWkwXLa3Cwa7lOcH/9pD\neXEDtz54FuoujkELtp99LXXaV3oznqIo8dLTm0hIsbDyBzND0o9X/vQNOr2G7912Wkj6OBCUFtbz\n0Vt7mTUvg7lnZfW5nzUnrbzzj51MmZnKmUvGhbCnbQn1eH79RR77d5Zx+Q2zSEju+1F0u7YUsXVj\nAVf/cC4xCb2vQzRQhMv2jgDsVjdGs7ZLod4ThmPIo8PuQRSlkJijFMwWPXbr8KpuGKr6KAqBujnD\nLJY9lKGvcjtKlNTwS9zrirBgH6JIkoTN2vc0+tZERA6/kMdQJeW0JsKix+cTh1V5hUCSVojGQT5R\nSh1wTA8X7FY3KrXQp+qWrVHGczifVdARYcE+RHG7fPh9Ysjs69CqNsYwKlVqDziQQ6OpwvAM/VQW\n41Bp7CDb2a2NrmG1c7Fb3Zgj9CHzEQV8DcO48mlH9FmwV1ZWct1113HRRRexbNkyXnvttVD06zuP\nLYQhfgrDWaCFUmMfjg5UpU5MSOdDpB6vx4/X4+/+5iGAKMqRQaGIjlIwRegQhJGnsfc5KkatVvOz\nn/2MiRMnYrfbWblyJfPnzycnJycU/fvOEuqIGBie2afWfjDFBBY42zAah0YXRpMWjVYdsjbNES0L\nvU4/9APkHHYvktTS71AghwHrh/VZBR3RZ409ISGBiRMnAmA2m8nJyaGqqqrPHfuuE8oYdgVThKzp\nDCfB3qKxh84EEXAiD5NxCPhbQjgGMPwWuP5QdkAOIW1qdCGKw8ck1R0htbGXlpZy5MgRcnNzQ9ls\nv2I/sA9nfv5gd6Md/TGJ1WpV83Fg7V9kT2UFjsOHQvasUKE4y3p7aHFHKFv51uMgSRKi14vo9SL5\nhpZT1enwIvqlkO5aAMzNC73d2lI0S3S7se3Zjd/WtuyszWbj/fffDfxbKaHbEU8++XuKigq7fX5X\nbbRGKcMbeCeC0NhffvnFoMvwRkQakEQJR/MC9/bbb+JyOrHu2IansmJIlOHtKSHbf9ntdu69914e\neeQRzGZzt/cHG4/Z3xT+8xU8dfWkLL2IjB9ci1rfdtIMVj+V1OnRmXHExoduPKNiTVSWNRIfFwGS\nSPlHH1P15QYchUUApF9zFaOvvrL3HQ9RPxUcNg9R0UYSE7tPew+WSItcVsDr8ZOQYEH0ejn8uz/Q\nsGcvxwFUKjK+/z3SLuu/lO+eUOlpBCA+IaLN+PV1bqamRcv/I8lt+Z1ODj3zJE2HDiOo1URNyyV1\n6UXEzJqJ293IRx/9h1tvlZNqoqNN6PWaDvvw9NNPtPm3co8oim2SzLpqozVarZqYGBP2etlhmjIq\nqsvviKLIT3/6QPcD0ExisoXjh6vQqNUkJFh49+03mJF/HOn4CZIvWMI//vFy0G0NFUIi2H0+H/fe\ney+XXnoGQVm+AAAgAElEQVQp5557blDfGSpJIEm33U3l6r9R8dHH1Gzbwagf/wRdQiIwuMkqNVXy\nc90eb7d96Ek/DUYNol+iuKgW19frqHn3bVCrMU+bjqesjJI3/43D4SFu2aV9/g196SfIafQ2q5vU\n9KiQ/x10ejX1tQ6qq61UvfUGDXv2ohuVhikhDmt+AUX/fANfXDLmyVNC+tzeUFoip5urNEJgHEIx\nN33N1SGrKps4WVpD2XPP4Dx2FOOEiYgOBw27dtOwZy8Zj/2Gx//2V4qLi1m27BJmzz6NM86YT0ND\nE7fddme7Urv33HMbd999H+PHT2DJkrO46qpr2bbtW+6++8fY7fY2ZXg9Hl+733FqGV673Ul9vYP6\nCh8V1cf42aOrEVRSuzK8F198Cdu3b2XlyivZunUz8+efGVQZ3sYGG3GWCZQUTeTdF/8fVSdP8ugX\nnxEdHcOq85exaNHZg16GVyHYxTwkgv2RRx5hzJgxXH/99aFobkAxZmeT8cvfUPOfd2hY+wU1775N\n6h13D3a3sFvdGIyhdZZBS9hgQ1k19o8+QB1hIePXv0MTFYW3tpbS/3uC2g/eR9DpiD3/wpA+u6co\ntt9Q25ahJUnJumsnDWu/QJeSyuhHHiUpLZ6SbXspfuL3VP79RTIe+y2a6OiQP78nKONgajZBbF5/\ngsK8GsQ+HhaimJSP7KvEsXsHWXlHiZg9h5RbbkdQq7Ht3kX5qj9R9cY/uf32uykszGf16jcAWUB2\nVGp36tRpbZ7hdDrJyRnDD394Gx6Ph6uvXtGuDO+pdFaGt7qqhgN5a3nxpb+RkBTdrgyvTqdn1aqX\nALnMMARXhvfEkZM89PC9HDt8mNOLi3lPq+WPj/6G1IVnN4dVDn4Z3p7SZxv7zp07+eijj/j2229Z\nvnw5K1asYNOmTaHo24Ch0ulIuOp76DOzsO3cgfuUQkEDTX8kJykoNvvyz9Yjud3EX34lmqgoALRx\ncaQ9+DDqqGhqP3i/nZ11oOmPUEeFCIset8tH+auvIOh0pNx+F6pmM5whK5uEK67Cb7VS8dILSOLg\nnrak2MAVm3ioUELBRb+Ir64Oc+40Um6+DUEtKxMRM2ZinjET57Gj2Hbvavd9pdSuIAiBUrunotFo\nWLhwMQBFRYXtyvB2xJ49uwPXWpfhPX7iCI22kzz007u48cbv8emnH3Py5MnA95SCXa1pXYbX7/ez\nZcvXnHmmXE543brPuOmm7/PL395Do/Ukx3ftQLTZEIwmLDNntYqVb1+G9/jxvEAZ3m3btgbK8N50\n07UUFxdRWjq4MqTPGvusWbM4fPhwKPoyqAiCQPzyFZQ9+wy1H35A6l33DFpfPG4fPm9ok5MUApl2\nReUk5owhct78Nte1cfHELDmfmnf+TeNXm4i98KKQ9yFY+iPrVEFxwDk9kPW9a9GPGtXmevQ55+E4\nchj7nt3Y9+4hYkZo6tT0BsWpp8yHeYtzuPSq6SExT73+/Ld4GxsZW7uDhB8/jqBpKxISr7qGwoMH\nqPv4w3YLXDCldnU6Xa+SiToqw+tyekhLnsA//vFih98xGjs+OKW7MrySX8Odt/4Ya1UtKrMZVSft\nwOCV4e0p4czTVpgmT8WQnYNt905cxUWD1g8lWsPcLwJN1vpcmggSr/0BQgcVE6POPAtBr6dh/dpB\njRCx2xRNNfTjYDLJWqkvbhSR889sd10QBOIvXQlA49eDuwPtL40dwKgRcUtajJNz0aWktruujU8g\n9qKlaB0ObLW1PW6/dVZrR2V4O6KzMrwW4yiq6gq6LMPbEd2V4XW6rZRXH8EraIg9/0LM5oghV4a3\np4QFeysEQSDuUnnVrf1wzaD1w94PMewK2qZqAPyJ6RhGZ3R4j9pkJmr+Anz1dR1uwQeK/opbBlDX\nyMJFGJ/b4eIGoE9PR5+ZhX3/PnwNDSHvQ7DYbW40GlW/JBFprDVIggrDWZ0HPcScfwGRlkhydDqu\nu+5q/vrXP7W7p7WG3dn/63Q6Hnro5zz44I+4665bSOlgIQG5DK/D4eCGG77Hm2++zqRJU/B6fKgF\nI8vOv5lf/eoRrr/+Gm677SaKAwpY57sCpQzv1q1bmDdPXsRbl+F96onfkBSdgU+tJ3rxOYEyvD/6\n0R3t2u6sDO95553P7bffyPXXX82jjz6M0+notD8DQbhs7ylIkkTJE7/HdeI4s/72PFbVwJ+LeWhv\nORs/OcbZF09gwtTkbu/vSYTEybfe5D8FSSTGaLns9vaaqoLnZCWFP/8phpwxjP7ZL4Lue6j6CfDZ\n+wfJP1rNdXefEXKtffsfVrFDmMycuUnMXjyx0z42bFhP1euvEb/ycmIvWhrSPgTLK3/+Bp2ubZnh\nkETFNNTz6ZNvURI1kcuun0liSuchpZWvrKbp602kPfAwpgkTO73vVEIVWVZfY+etv29n4rQUFl04\nvs/ttWl7/Vo+3tSI0xTLzQ8uGtJnFYTL9vYSQRCInLcAgLqt2walD/Z+qBMDIIki9p3b0IsunGLX\n2p8uKRlz7jRcJ47jzD8R0n4Ei8Mun5xkNIXWBOEuL0dVIv8mp6/rqCPL3NMRtFoav/lqUIpl+f0i\nTrs3kDUcShq+XI/eKzvIFbNXZ0SedjoA1m1bQ96PYLDb+m/3Ztu1E53fgU8Uhk3dnO4IC/YOiJg+\nAwSB2i3fDsrzbf1kgnDmHcNXX49Rr8Jh93QrqKIXy9tza6tjwAYSu9WDyaxDpQqtBtX09Sb0Pnvg\nGV2hNpmImDUb78mTOPOOhbQfweC094+fQZIkmrZuwaCSfSjdFYYzjp+AOioK687tg+J3USKkQlkA\nDMBvteI8dpSICNkRPFzKK3RHWLB3gCYqCuOYsTQdPoKvqWnAn99iYw/tJLZukxeqiPhI/H4Jj7vr\nF9Q0YSIqgwH7/n0Drq1KkoTD7gm5pir5fDRt+Qa9SYtaLQRV4TFqwVkANH61MaR9CYYWB3Jox8FT\nUY6vpoao0SmAnOHbFYJKhWX2XES7HfuhgyHtSzD0lyPdtncPiCIxaXJSYncL/XAhLNg7IWLGTJAk\n7Ht2D/izbVY3Or0arS50zjLJ58O6cwfqqCgik2SnT3eTWNBoME2egre6Gm9l+xjl/sTjluvRm0L8\nIjvzjuG3Womae7qcpBSEhmYcPwFNfDz23bsGXFvtLweyfd9eAGInjmnznK6wzJVt/IqCMJD0V0CB\nbfdOABInZAItoaXDnbBg74SIGbOAlj/8QOKweUKumdgPHUS02bDMnoup+eVw2LufxObmTEJbsyAY\nKPorxM9+8IDcbm4uZoseh82Dv5sMTkEQME/JRXS5cBUMbME4RZMO9c7Fvm8vCALxM+SSCcEscIbs\nHLTxCdh270Z0D6wA7I8FTnS5cBw8gG5UGjHpSfJzutm5DBfCgr0TtAkJmLMycRw+hN85cLWa/c1H\ntoX6Rbbtkhcoy9zTWqr6BTGJzVOnyvfu3xfS/nSHsuiEWmN3HDyAoNFgHDs+ICQUO3ZXmCdPBhhw\nM0TAaRjCcfA77DiP52HIysIQG41OrwlqLgiCgGXuaUhu14BXArXb3KjVAnpD6Hax9gP7kXw+ImbM\nDJykFLaxfweIPf00JJ8P+/6B01Zb6oKEVrA7jxxGZTJhyMoOtN2dXRVAExWNPjNLNmEM4ALXHxq7\nr6kJd0kxxrHjUOn1LQePBGGGMI6fCCoVjmaNf6DoD03VcfAgiGJgN2a26II+VcvUXBTNcWSABbvV\ng9kSuiPxoGU3HjFzVuDIwWDeieFAWLB3QdzpcwGwD2CSjqNZezSZQ/cie2uq8dZUYxw3HkGlajk5\nJ0jtxDw1F/x+HIcGTqgFxiGEgt1xWNa2TZNk4WQyB7/AqU0mDNk5uAry8XeSldgf2PvBFKPY1825\nzYI9Qq6b4/N2H+pnyM5B0GpxHDkSsv50h9/ffCReCHctkt+Pfd9eNHFx6NNHN5+jGjbFfCcwZWSg\njo7GcfTIgEWFOPohCkJ5CZXEkp4INICIZgFg3zdw5pieHKoQLIq2bWo2q/Rk5wLIJXwlaUC1VbtN\nPrZOG6Iqn5IoYj+wD3VUNPrmzOOWk5S6HweVVotxzFg8pSX4rAMTMRYI+QzhrsVdXITodGKeMhVB\nEFCpBExmXdh5+l1AEARM48bjb2rC26qKXH/SH84yx1G5SJsi2I1mbY+0E31GJmpLJPb9ewes0mGo\nw/wkScJ+8CBqiwV9Wnpz280CLQgnMoBpkrwgOAbQzu6whfbwZldhAX6rFfPU3IBZo+UkpeDGwdg8\nj5xHj4asX13RktcRwnfimNx347gJgc9MEXrstu7zO4YDYcHeDcaxcvqy89jATGJFyChadV+RJAnn\nkcOoLRZ0qXIFQ5VKhdEUvHYiqFSYp0zF39SEp6wsJP3qDiXr1BCirFNPeRn+xgZMkyYHasP0VGM3\nZGahMhqxHzwwIC+/z+vH7fKFdtfSvCgpTnHo+dmnioLgODIwVV0d/RDDrrzPxrHjAp+ZI3T4fWK3\n+R3DgbBg7wbjOFmwO/IGRrAHJnGItp3eqpNytun4CW2KXZkidEFlnyooL4DzeF5I+tUdoc46DZhh\nJrWciNRTk5SgVmOaOAlfTQ3eATiwvT+Sk5S/n6KwyO03C/Ygk3MMGZkIegPOgRLsIfa3SKKIM+8Y\n2oQEtLGxgc+VMOCRkKQUFuzdoEtJQRURMWAae8AUEyKNXdGqTi3cZI7Q4fOKeNzB1cYwjh0LDIxg\n74+sUyVMUQlbBNDpNWi0qh5FQgSiQgbAkRyIkArRIi+JIq4Tx9EmJaGJbCn4pZg4gtXY5XDRcXgq\nK/A19H952lC/E56yMkSHo83iBq1MUiPAzh4W7N0gqFQYx47DV1uLt7am35/nsHnQaENXotXZiWBX\n4sODSVIC0CY3L3An+l+whzrrVBJFXMfz0CYno4mOaXPNZNYFbWMHMI1vti/n9f84hNqR7ikvQ3Q6\nMeaMbfO5orH3xHFomiDbph1H+z865tSjAfuKsvtWduMKph7kdwx1woI9CEwBO3v/F4Gy290hsyVK\nkoTjyBHU0dFok9qW/+2xGUIQMOaMwVdT0+9aWqhj2D3lZYguF8bsMe2umSL0uBxeRDE4k5Q2KUle\n4PKPh6RvXRHqyKCAGWZMW8FuNMsFsHq0c5kwSf7OAJye5rSHVmMP2NfHnaqx93yBG6qEBXsQKBPA\n2c92dlFsLtEaqi1nRTl+axOm8RPbJXa0bL+Df5kVgdDf5phQZ50qZYcNOe0FuzlChySB09GDBS47\nR17gGvv38I1Ql6pV/m6GUwS77EzXYg8iA1dBP3o0KpMJ59H+F+x2m6f5oJG+h3xKkoTz2FFZ2UlI\naHOtJToorLED8MgjjzBv3jyWLVsWiuaGHPr0dFQGQyBEqr9w2r1A6JxErhOyVqnYx1ujJED1RDsZ\nKMEeao3ddUIW7MacnHbXerpzATlJB8DVz3XqQ+08dR0/jspsRpfc/vAWU4QuqNIKCoJKhTFnDN7q\n6n6vgKr4W0KRdeo9eRJ/UxOmcePbtWfqYeLeUCYkgn3lypW8/PLLoWhqSCKo1RjGjMVbWYmvsbHf\nnhNq779SsEoRRK3paagfgD4zE0GjwXm8f80QIR+H/BOoDIZAuGdrejMOxmbN33mifwW70idjCHZw\nvoYGOfs4Z0yHRwGazDo8bj/eILJPFQxZ2QD9WhhNFCWcdk/ozTCnOE4BjCYtKpUwIsoKhESwz549\nm8jIzo/VGgmYAuaY/rOzh7rgkzM/H0GnQz8qrd21nhQCU1BpdegzMuWsvX6s7hdK27LfbsdTUY4h\nK7tjgdbDJCUAQ1YWCEJgR9RfOOweDEYtanXfX9PO7OsKyjj0RGs3ZCuCvf8WOJfTiySFbpFX3l/j\nuHHtrgmCIIcBjwCNPfSn445QAtvvgnwss+cAYPPa2V65G5WgIjMyndSIFLSq4IZUkiRKqmwcKqzH\n6/Oj16pxVcs1SEKhnYhuN56yUoxjxiKo29smjQETRM8msXHMGFwnjuMqyA9E2tS7GjhQewS7186M\nxFySTAndtNKCy+OjqNJKQYUVm9NLXJSB6pPyGZmhMEEoQqejXUvrZ/RES1MZjOhSR+EqKkTy+RA0\nGlw+F2W2SmpddVg9NqbGTySxB+NQ3eCk+KSVmkYXjXYPybEmbFY3lsj+ta8rKHPObvMQGR3cOb+G\nTEWwFwQ+a/JYOVhzBLVKjU6tY5pxLALB/4ayGjvHShpweXx4vCKG5jyLUNVOchacQGU0ouvkIG1T\nhI6aShuSJA3ps0+7Y9AEe7CHsg42Sj995qmUCgL+smI0ESJrDn/G+vxvcPtbBIJereOHs65mUdYZ\nnbbncvt4d30ea7cXU9voanMtFRiFig0HK7GkRzNtbPCC4dTxbDxYDJJEzKTxnY61KUKH2+Xr0d9C\nNWsa9Z99iqqiGPuUFJ7f/k8K6ksC1z/K/4zxcdlcPuVipiVP6rSftY1O3vz8KGu3FeM/JSJlAgIR\nCHyxr4JLzxpDQkzvDxR3VpYCkDRzKrEd/E7RKz9b8kuBvgUzHo1TJnLys1JM9joKLV6e3foyTW5b\n4PoHJ/7HOTkLuHzyxUQbOt7NSpLEoYI63t9wnG2HKmmdKyYAs1FRWu9k8+EqLpqXiVbTdoHuyd+t\nvCgfQaMhbfZU1Pr2QjIxSW5Lq1YF326ChbKUZNyFBcTGGvmycAtv7H0fu7elCqjmoIZrc5dz4biz\nUQkd7zx8fpEvthbxxbZi8kraOqQjgfGo2Ha8moRJSSyYntprgeuz2zlWWUlU7lQSk6La/5wECzGx\nZqrKrUSY9CEvGT2QDJpgD8XJ5f3NqSes65JTsB4/zs8+e4I6dwMx+miWZi3BrDVT2FTCjpO7+eu2\n1yisKueirPPaTEBJktiTV8O/1h6jtsmN2aDh9MlJ5GbHYTHpcHv9HPy2GFu5lX2FdWx9YTNn5qZw\n9TljMXYT097RSfB1u+WEHCk5vdOxNpq0NDW4evS38CXIduqSndt5Ub0Zl8/FxNhxTImbiElrZGvF\nTo7WHucPm1Zx69TrmBrfItwTEixUVDby4TcFfLatBK9PJCnWxPQxcWSlRBJl1lHX5GbfF3l4PH4+\n2JTPf78uYMVZ2Vxw2mhUvXiha/fLBbs8cakd/k63V3ZY19bYqa62djiWHZI6GoAvv3iP12MLUQkq\nFqbNJ9mUiFqlYm3RRj4/vomvCrfx4xm3k2ZpqyHaXV5Wf3yY3XlybkRWSiRzJiQSH2Ug0qyjsKSB\nE5sK8UgSf//gAO9/mcfV54xj1viEwFgG+3cT3W5s+QUYMjKpa/IA7XcnIvKqUlHeSHxK8AuGdnQW\nrq1b+MM7v2efUIlBrWdZ9gVEaE04vE6+LPuKV/e8y9aivVw/+WoidW3bLqux8/f/HqKo0oogQG5O\nHLPGJWAx6dBpVRzeW0HV4WqqrW6een0H//06hmvPG0dKnDnoPiooNeRVqe3fCWU8NVp58SkuriMu\nIaLHz+hvgl10QybYR0LhnO7QjE7HU1GOVFXDBdPO56LMc1GrZC3qtJRZLEybx1/3ruZ/hWupdzdy\n7YTLEQQBUZR4Y+0xvtxVhlolcPEZGSw9IxO9rq0GdnJfJTbgnqum8eaXJ/hqXwWHCuu5fflkclLb\naxhdoURsKHbQjjBF6KmtsuP1+II+hk9jiUSKjcZZkI97ZgLXTbqKuckzA9fnJs8krz6fv+59mb/v\n/ye35t7A5DjZP9Fk9/D/3t7L4aJ6Yix6li/IYt7UZNStbN+SJLH/02MkJ0bww9mjeG/jCd7dcIJj\nJQ3cvHQSEUZt0GMgiSKu/BNok5JQR3T8khqMvXOYGZtNO1VH9hC5KI2bp/6A7KjMwPXTk2ezsWwz\n7+V9xPP7/sFDs+8hSi9r7gUVTTy/5gA1jS7GpUez8qxsxqZFtVEEIlUCJyhk/oxR5Khh3c4yVr2/\nn0vmZ3LJgqwe9dVdUgx+f9dzQTHN9cDGDqDPysK6dQuewgJy58zmqvHLida3zNWLpy7iua//wcHa\nI/xt32vcN/P2wDuzbmcp/15/HJ9fZP6UZFYuzCHmlNBOZ7mVqsPVfO+C8aw7Ws3+/FoeW72dW5dN\nYvaExB711VUom4wMWZ2PnzIOTrsHgt8wDzlC4jz9yU9+wtVXX01BQQGLFi3ivffeC0WzQwqv38s2\nXSUAC8VMlmYtCUxQhWRzIg/MvovRllFsqdjOloodeLx+Vr2/ny93lZGWEMGvb5rLZQtz2gl1kF8q\ntVpgfGYsj14/m6XzMqizunj6zT0cKqzrUX9dBfmoLZFoYuM6vcds7rkDtdZZzwmLG6Nb5Oa0S9oI\ndYWxMdncnnsjgiDw0v5XKWgspqzGzk+e28jhonpmjI3ndzefxpnTUtsIdWjJOjVb9MyfmsKvbpzL\n5KxY9p2o5TevbKemIfjDPjyVFXKmZQeJSQqCIGDsRbnWfK0Nl05gVK3Iw3N+3EaoA6hVahann8ml\n2RfS4G7kxX2v4vF72Hm0isf/uZPaRheXzM/koWtmMC49up15QVlooqMMXLV4LI/dMJuEaAMfflPI\nX98/gKsHhapcRYUAGDK6EGi98DUA7NLLO46JNjO3TP1BG6EOEG2I5I7cG5mdNJ2CpiI+yP8ESZJ4\nb+MJ3vjiGEa9mrtXTuWHSye1E+rQ4sxNSbLw4ytyuXP5FNRqgefXHGDtjpJ293dFQLBndqXs9G4c\nhhohEex//OMf+frrrzlw4AAbNmzgsssuC0WzQ4qPC77ggFGO1811xXRq54vUWbhl6nUY1HrezfuQ\nJ975ht15NUzMiOGn184kNb7zLaTdJod1CYKARq1i5Vk53L1iKn5R5Nl39rEnL7iSBr6GBnx1dRiy\ns7u0R5osPZvEkiTx5tH3qIyRp022tXPn5vjYMdw85Qd4RR+vHnybJ/+1g8paB8vmZXLXyqmdmpdO\nTaOPNOu478ppLJ2XSU2ji6fe3E1NY3DCPbBr6SB+vTXmCB32HhREs3nsvHbk31TGa7FYvZjdnX/v\nvIxFnJ48myJrCX/Z9i9e+OAgGo2K+66axvIzszstcnZqyOeohAgevX4OE0ZHs+tYNb//xza8vuBC\nE92FhYBcfrkzeqOx76s+yIeu3fhVkNOk79SGLggC14xfSaIpnnXFm1i1di0fbykiMcbIo9fPZua4\nzlXj1rH8giAwe0IiP/3eTCLNOv61No/3NgYfkeMqKGhWdmI7vUcJKuhJstZQJJx5GgQn7VWsL/kK\nX1IcqFS4iwq6vD/WEMOKMctw+92UGzczd1Ii9105DVMX5zVKUnO87ikOmxnjEvjRFdNQqWDV+/vZ\nd6K22/4G4tezOtdMAMzmniVkbKvcxeG6YwGNx11U1OX9U+InMit+FtWuKlxRedxxWS4rzsru0lau\nCJbWsdsqQWDlWdmsODNLFu7/Ck64K9Ea3Y2DyaxD9Eu4Xd1rwZIk8caRd2n0WIkeI0cFKZpgRwiC\nwDUTVhKvTeaE8xCaqHruv3IaU7I630lBx4WvIoxa7r9qOtPHxLMnr5oXPjiIr5uDuEHW2AW9ocPE\nJIWeFkRz+py8ceRdVFodmrQ0vKWliN7Ov2vQGPjh5O+jktQckjaQkqziZ9fOJD6qa8d4R+WbM5It\n/PwHs0iKMfLxliI+3VrcbX99TU346moxZGV1qewoCoUzrLGPbCRJ4p28D/FLflZOvBR9Wjru4mIk\nX+dCQJQkDuw04a9PQB1Vx4QZTWi6iUV2OeV6JR3F607OjOX+K6ejUgk8/8EBik927TQLVrD3ZNvZ\n5LHybt6H6NU6zjvjGvk5XQg0gEa7h6Nbk5G8OvTp+czO7d7x4+iiLsiy+Vksbxbu/+/tvTi6EcTu\n4iJQqzuM429NT8Zhb81B9tUcZGx0NhNzF7Y8pwsKym1U7pPNIElTCsgZ1X3OR2dJWhq1ijuWT2ba\n2Hh259Xwj/8d7nKnIbrdchz/6NEdxvG3xmTWBa2xf160AZvXzgWZ5xA1Zjz4/biLuxawhw77cBWN\nQ9B4mTCnhqggok4cNg9GU/vyzfHRRh64egbRETre/vI4Ww5UdtmOMle72rVA730NQ42wYO+GvTUH\nOVx3jImx45iWMAVDZhaSz4e7vPMDJ9798gTbDlUxyn0GBrWeTwq/aBMW2RHdnZw0Lj2aW5ZOwuPx\n8+w7e6lrcnV4H7QW7F072QICLYhJvOb4/3D4nFyacxEJcaloE5PkOO5OhIrXJ7Lq/f1U1/qZrJ+P\niI+Xd/272+d0V6L1kvlZLJmTTkWtg+c/OIC/kxOdJJ8Pd0kx+lFpCJquHcPBVroUJZGP8j9DQDYt\nKDZrxYbdETUNTv7yn/34bdGMi5hMtfsk31bs6PI50PU4aDVqfn7jaeSkRrLl4En+u7nz57uL5bBX\nfWb3DldThB6n3dNtQbQ6Vz1flnxFtD6KxekLMGQpOR6dL/Q7j1bx7/XHMTtyiNPHsa1qBycd1V0+\np7vyzXFRBu6/ajomvYbV/zvMwS78UO4gHKcARlNYsI94vH4v7+V9hFpQc8XYSxAEAUPzC9LZJN5y\nsJJPtxWTEmfivhWncXb6mdi8djaVbu7yWQFbYhfJSbMnJHLl4jE02Dw89+4+3B2kf0uShKuwAG1S\nMmpT1yFhwdZJqXLUsK1yF6nmZM4cdToAhsxMRLsdX017u78kSbzxxVGOlzYyd2Iid5x1PuNixrC7\n4gDHG7rW8oMpJ3Dl2WOYlhPHwYI63lrbcfanp6ICyefDkJnZ5fOgbXJOV+w4uYdK+0lOS5lFkjkR\nTXQ06sjITk1STreP597bh9Xh5drzxnL9tOXoVFo+PPEpTl/nCzO0ONI7K99s1Gu457Jc4iL1vP9V\nAbuPdSwkXUWKwzCzy+eBPA6SJO8eu+Kj/M/wij4uyb4AnVrXUlqgsOPSAkWVVv720SF0WjX3XT6D\n5WMvDCySXeH1+PF5xS7nQlpCBPdenosgwAtrDlDdiXM9GMcpgFqjQm/QhJ2nI5mvirZT56rnrLQz\nSN51YmQAACAASURBVDLLoVXKit/RJC6qtPLqJ0cw6tXcc1kuEUYti9PPxKgxsLZ4Iy5f5xqhI8ia\n00vmpLNoeiolVTZe+7T9IdvemmpEpxNDN1tOCH7b+VnReiQkLsg8J+AgU7a0rg78Det3lbFpbwUZ\nSRZuvGgiKpWKZdnny20Vru/yWcEcqqBSCdx6yWRGJZhZt6uUTXvL292jaNHKgc1dEczOxS/6+Tj/\nc9SCmosyzwVk+7l+dCa+ulr81rbmMUmS+Pt/D1FWbeecWWmcPTONaH0USzIWY/XaWF+8qcs+OZr9\nLV3ZgyPNOu65LBedVsXf/nuI0mpbu3sCAi2I+dCShdv5PC2xlrG9cjdpEanMSZ4BgDYxEUFv6NAU\nY3V4WPX+frw+kdsumUxGsoUZCVPJsKSzu2ofRU2dR7Z0ZZZrzbj0aK49bxx2l49V/9nfTuGRJAlX\nQQGa2Lg2B4x0hnK62HAmLNg7QZREPjwiv8jnjl4Y+FyXOgpBpwts7RRsTi+r3t+Pxydy89JJJMea\nADBpjUFp7cEWvhIEgWvOHUd28zb8y91tTUKK9qgfPbrb36jRqtHp1V1O4hpnHdsqd5FkSmRGYss5\nmYqgcDVHXCjklTbw5to8Ik1a7rlsKnqtHNaZHZXB5MRxHKo7SnFTaafPC/ZlNuo1/OiyXMwGDa9/\nfoyiyraC1V0s90s/OrPLdiC4sgJbKrZT46pjfuppxBlboioMGfLC4TrFzv7p1uJANNTV57SEWy4e\nfSZmjYmNZZvxdGKeCzjSgygtMTrJwg8vnoTb42fVf/bjPCUM0l1UhMpgQJuY1G1bxiAW+k8L5UV+\n+ZiLAou8oFJhGD0aT0V5mxpCoiTx9Bs7qWkO7Zw+Nl6+XxC4NOdCAD488Wmnz+rJwe4Lp49i4fRU\niqtsvHqKwuOrq8NvberWDKNgMssZ2X7fwBzc3h+EBXsn7Ks5RLn1JHOTZ7aJzRXUavTpo3GXlSF6\n5IknShIvfXQoMIFnnFIKYHH6AowaY7PW3vEWvCfHf2k1Ku5cPoUIo5Y31+ZxpJVtUXHkBaOhKc/r\n6kX+vOhLREnkgszFbcLZFE3Y3cq+3OTw8MIHB5GQuGP5FGIjDW3aWjHxAgA+K/qy0+c57B50ejUa\nbfe1t+Ojjdy8dBI+v8hf1+zH4WoxIbiKikClQp/eteMUujdJ+UU/nxauR6vSckHm4jbXlJ1L63E4\nWlzPuxtPEB2h47ZLJreJ1derdZyZdgZ2r6NTW3tXjvSOmDMhkQtPG83Jeif/+KRFqIkuJ57KCvSj\nM7p1nEL3C1yNs5a91QcYbRnFhJi2NWf0ozNAknCXtSzaH35dwK4jVUzJjm2XVDU+dgzjonM4Up9H\nqbX9jgtaFcULsk7M984dR05qJN8ePMmGVgpPixkmSMHeA9/TUCUs2DtAkiQ+L/oSAaGNtq5gyMgA\nUcRdKk/iT74tYn9+LZOz2k9gAKPGyDnpZ2L3Odhcsb3DZ/a0VG1spIHbL52MKEk8+c8d2Jrtoorm\nqE/vXmMHWai5HF78HYTN1bsa+LZiB4nGeGYlTmtzTW0yoU1KDjhQlcWt3upm5VnZjB8d0669qUkT\nyLCks7f6AJX2kx32x9HDEq3TxsSzdF4G1Q0u/v5fOUJEEkXcJcXoUkeh0nbfVncF0fbWHKTe3cAZ\nKbMD2aMKp2rsjTY3z39wEAGB2y+dQmQHv2Vh2jw0Kg3rSr5ClNqPe7C7ltasaM5e3XGkivW7ypr7\nJDtOgxVoxm58DV+WfI2ExOL0s9qZiJQdorJjPFhQx0ffFJIYY+TWZZM7DHFdPPpMud3Srzt8Xkeh\nr12h1ai4Q1F41uUFdnGKshOMWQ5GRmRMWLB3QF7DCYqaSpgzahrJ5vZpywFttaSIYyUNvL+pgBiL\nnluWTeo0RvvMUWegUWnYVLq5y5fZaAo+ZX5SZizLF2RR0+Dk7/89hF8UcRcVoYmL6zSF/lSUhcTl\naO8w21S2Bb/k57yMRe2ybEHeFYgOB97qaj7eXMjBgjpyc+K48PSOXyBBEDg/82wkJD4v2tDuut8v\n4nL0/ASp5QuymZgRw57jNXy2rQRPZQWSxxP0rkWtVmHo4gShDSXfALAwbX67a5rYOFQREbiLChFF\niRc/PEiT3cPli3IYlx7dYXuROgunJc9s1oAPtrvem8ObNWoVt186BYtJy1vr8sgvbwoqMak1gRju\nDsbB4ZWVkmh9FDMTc9tdNzSbvNwlRdRb3fzto4OoVAIPXzen0zIQk+MmkGiMZ0flbqye9v6B3pz5\nGhtpaN7FSc27OF/LLjZowa4cQhMW7COKdc2OrUsnLunwuiLYrScKeOED+bT62y6ZTKSp8wkYoTMz\nO3E61c5aDte1P4HIYfc0F/rv2Z/k4jMymT4ugX0nalm34SB+a1PQmgl0blf1ij42l2/DrDExO2lG\nh99VIi3yt+9nzdcFxEbquXlp54sbwNT4SSQa49lZtReb197mmtOhnCDVs6p6ijM1yqzjvY0nKN4j\nH9emzwh+HMzmjk8QKrGWcaKxgImx4zpc5AVBwDA6A+//Z++9oyS560PfT3WOk3ty3JyjNiqsJAQS\nCiRjHgbDRRhjHDg8Xb/jc1+wr6/TxX6PCxiuMRgso4vBZIQQKGu1knalzTnvTs6xezqHqvdHdfX0\nzHRPV3XXzG6P+nMO54jpqq7f/vpX39/3942jo/zqlYtc7pli++oaHtzdsuDz3tVyDwICL/W8Ns8B\nnm+jkUq3lc8+thFRlPjGL87jV8JeVUTEwMLRQW8OHCWaiHJv850ZN3lLQ4Ncvri7i2/98gLTwRgf\nuX8VazKc3BQMgoF7W+4iLiV4vf/IvM+12NjT2bKymkf2yae4J399iXBvD6bKKoxudQW0SqaYZch4\naIIL41foKGtldXXmI6y1sQmMRgbOX2HKH+VDB1Zk1c7SOdCyH4BDfW/O+yzfLjEGg8Cffmwn5S4L\nxw+eBtRrJpDdvnxq5Cz+WIC9jXdgMWbWuJQN5PQbZzAIAn/4/k05i3QZBAN3Ne0lLsbn2ZgLaVpc\n7rTw2ffJpqmzb+QxD65kB6HobOfjweRvdW8GbV1BmYdTh85QU27j04/M7zE7lzpnLZtq1tPl66Fr\nTmRIPqYYhY0dVTx2ZzvjvjCjF6/JjlOPumJZNocFQZgv0BJigoN9b2I1WrizcU/GewWTCUtzC6He\nPq71TLBzjYcHdub2b+yp34ndZONQ/xFi4uy5L2QePnB3B2tbKrh0sYfE1JSqYAIFRw7TXDFQEuxz\nODxwFAmJu5Lx2pkQTCZC5R5c02NsX1HFQ3vULZpWdzMdZW1cGL/CaHCmNEAsliAaSeTdJabCbeVz\n79tIXVj+zkRt5iYCmZhJzpn9Mh/qO4KAwN2N2WvLm5plrbQiMMZv37eKlU3qKlDubbgDs8HE6/1v\nzTJL5auhKaxvq+T9d3VQ4RtBQsDctLDWnI5ycvGnNTKejvo5PnyaWnsNG6rnt1JTiNfKpYwbouP8\n4Qc24bSpM6cdaJI3+jcH3p7190Ln4X13drCp2YUjMEmgok6V4xRkJcHumF8Q7ezYRaYiXvY27MJh\nzl4CIFBei0FMsMYa5vGH16mqm24zWdnfuJvpqJ+Tw2dmfabFkT4Xo8HAH7x/Ix2CbGcPVOSOClIo\n2diXGQkxwZuDR7Gb7OyY4yxM5/zNca7HnJilBJ/YVampTviB5v1ISBzqnwl9DGl0EmVibWslO8rk\nF/IHF0JZMzLnkmkR90730+nrZn31GjyO7DVNfnFsiCmTi6b4FA/snN9PNBtOs4OdtdsYC41zZWIm\nwUirsywTj+xppSE2yZiljGeOZY62yIQjJdhnopYODxwlLsY50Hxn1gJXsbjIDy7K9+yqiNHRoL5F\n5NqqVVTbqjgxfJpQfCaxphBNFWQB/cntZRiQuBiyc6l7UvW9mWK4lY3nrizaOsDIZJBDI7IA/vB6\nKw6VmxvIG5yAwBsDb836e9BfWK/TCpeVRzrkMT3fk8AXVCeoS6aYZcaZsQtMR/3srd+Z1fwwPBHk\nn5++wIhdFniG4eylBTKxvXYzZRa3XNI3IduU83GWZaJ8eoSIxcGZ4Rg/Oaiu6l0mU8yhPtneqWiU\nmXjr4hDPvd2D112DNRpE1Nip/u5m+USUblstVKABJMZGMSViTLk8PHO4S3VFzJR9OdlvVZREDg8c\nxWIws6dhfmlihe+/dJXzkxA3WamYHtE0VoNg4M7G3UTFGMeGTqX+rkcTa9PYIAAjtiq59rvKcscO\np4V4TCSajIcfD01weeIaK8rbaHRlLiIWiSX4p5+fp9con9hck5kjnrJRba9iXdVqbnq7GUxGSyUS\nIuFQrOAuRmU+OSP3pljGN35+XlXRNKvNJNfoLwn25cGb/UnNpCmzZhIMx/jqT84SjMTZcfc2gJyF\nj+ZiMpjY23AHoXiI06Pn5O/N01mWTsLvJz4+TvmqFdRVO3n+aC+vn82tsc7VTkLxMMeHT1Ftq8xq\nfugemubffn0Zm8XI6js2AvMTdHLR5m6hxd3E2bGLTIbldmip7NsCBFqkV/491u/ZjNlk4F9+dYHB\n8UCOu2bmwZ8U7NcmbzIWnmB77Rbspszmh4On+3nt9ACtdW6cHe3ERoY1N/ne27ALg2DgjYG3U05U\nPZpYK+ty54Ht+EMxvp4hIzMTc9fD4cFjSEhZbeuiJPHtZy7SM+Jn3a4NcvXTXm3vBMD+xt3y8waO\nAoX5W9KJ9HZjcDhZtbGdK71TfO+FqznLM880tS4J9qJnJDjG5clrrKrooN453x4nihL//MsLDE0E\neXB3C3fcK0eKaBVoAPsa5GbYR5LOQz00VeVlcrS384UPyxmZTz13Jecx3GY3z3KYnRw5Q1SMsa9h\nd0bzw4QvzNd/dpZoXOSzj22kZu2qWc9XiyAI3N20FwmJI8nYfj02OGUc9RtW86n3riMUSfA/fniG\nqRyOsJQpxidfd3hQFjCKwJnL2RtjfO/5q7jsZrm+fFurnKDTp635Q7nVzZaajfT7B1NO1KA/e+Er\ntUR65cqW++/blsrI/PavLuYs8JV+gkuICY4MHMNusmUMcQT46cEbnLg6yrrWCn7noY1Y6hsI9/Qg\nqTQFKmyp2YDL7OTtoRPExLgu74QYDhEbGcHa2spnHt1Ia62LQ2cGePlE9sxnBSVxr1g7w5UEexJF\nuNzVON9pKkoS//aby5y/OcHmFdX89r2rMNrtmGvr5BK+Gn/8WkcNqyo6uDp5nbHQuC6mmHBaEkZ9\nlYM/+ZCc/v8/f3aOgbHsGqviMFM0pCMDxxEQ2Nuwc961/lCM//GjM4z7IvzWgRVsW12TSoTKVbo2\nEztrt2IxmHlr8ASiJBIMROXa2xra380lPUFr38Z6Pnh3B+O+MF/58Zl56fbppNvYg7Egp0fPU+fw\nsHJOZySQW9v90y/OYzQKfOHDW/BU2LG2JHMbNJ7gYMZ2/cbAW8RjCaKReEFrQUomz1kbGzGYzXz8\n3WtY21LBiSujPPX8lQXXa7rGfmH8Mt6oj11127EY54/ntdP9/ObtHuqqHPzRBzdjMhqwtrUhRcLE\nRrSZY0wGE3sadhKIBTk7ekGnTb5PTtBqacVqkes3lTkt/ODlaxy9tPD4lBr9UQ2dqm4nSoId2Z76\n9uAJ7CYbWz2bZn0mSRLfe+Eqb5wbpL3ezR+8b2OqNrS1pQUxGCA+kbv5xVz2N8ia4JHB4/os4jkZ\np2tbK/nUe9cRjMT5hx+con8B4a5oJ0OBYTp93ayrWk2lbXb4Zjga58s/OsPAWID37Grh4WQSkqmq\nCoPTSaRXm6YKcvOFHbVbGQ9PcH3qplx72zm/9rYWIr09mKpmErQe3d/OPVsb6Rn28z9/fo5INLM5\nIt0Uc3T4FHExzr6GXfMiO/pH/Xz1x2eIxUU+976NqUggm5J5mYcZQnaiVnJy5CyTPjlRpxDBHh0a\nQopGU2vBZDTw+d/aQmudrLH+7FDmKozpzw0GoryZNItkMsO8fnaAp567gtNm4n//7S2pMFdbARuc\n8k4cHjiqi8Ye7p1dN6m63MYXPrwFq9nIvzxzMWtFTCj+FnklwQ5cmriKN+pjZ922WU5TUZT4wUvX\nOHiqn5Zal1z7Oa0LUioDNQ9tdXvtZmxGK28NHk/VAS/UFCPHLM/UqblzcwMff/cafIEo//D9k/SN\nzM/uA7C7LMSiCd7slU1DiqlIwReM8qUfnqZz0Medm+r5yP2rUgJPEASsLa3ERoZJhNT3I1XY23AH\nMLPBFTIHce8UCa93VsyyIAh84sE1bFtVw8WuSf6/H55KlV9Ix2I1YTAK+H0RDg8cxSAY2DPn1HKj\n38sX//0kvmCM333PWrantXSzNDSC0ZiXYDcIBvbU7ySaiHKm/zKgjzkqPVHNYTPxnz+yLdV16Iev\nXEPMoLkr8z/pnebixBVa3U00u2eHzx483c+Tv76Mw2bi//joduoqHanPlLkP5zEP9c5aVpZ3cHny\nGmNTXnk8eig7afPQ0VDGEx/Zislo4J9+cT6rcz1XeYXbHV0E+6FDh3jooYd48MEH+da3vqXHVy4p\niq17X1LIAATCMf76X9/mpRN9NNY4+dOPbpuXfKMkwITz0E4sRgs767YxFfEy4Z1esPZ2LhKRCNHB\nQawt87vkvGtnM598cC3TwRh///2TnL0xfyErNeBP9VzAaXKwxbMx9dnAWIC/+e5xbvT72Luxjk89\nvG5eeKcyD1GN9mWAVRUdeOzVnB68KNfeLmhzk58/t06O0WDgjz64ib0b67jR7+Pv//0k497ZxdgE\nQcDhtOD1Buj3D7K5ej1llplMxbM3xvh//+MUoUiC33tkPfdtnx3eKZhMWBubiPT1IiXU9SJNZ09y\n7V3ol7OSC5qHLPWCypwW/vSj22iodvD80V65xO2cE4wiSPvGhxElkb1pm3xCFPnF6zd56rkruB1m\n/uxjO2irn53NaU3mNuSzwQHsa5Sf1z0qO/4Lm4ceBLMZS33DrL+vbq7gCx/egtEg8LWfnuVXh7vm\n+R6cRR7yWLBgF0WRv/7rv+Y73/kOv/rVr3j22We5cUN9g9lbTSAW5NzoBeqddbS55UV5Y8DLX/3b\nMY5fGmZjRxX/5eM7MpYLSBU+ykNjB9ifXMTT06FUE+t8CPb0yl1yWjIn5Ny7vYnfe2Q9kViCr/z4\nLN9/8eqsRsj25CKOhBLsqt+O2WAiIYocPNXP3/6vmbKrv//ohlmVChUUAZKPI1kQBFlrj8jfq4um\nmqEAmslo4DOPbuCBO5rpHwvwF//6NgdP9c/SWh1Oi6yhSTMCJhiO893nLvOVH59FkuBPPrSZOzc3\nzPt+5blSLEZ0WJt9GaDGXsWaipWMTckRQos1DzXldv6vT+xkXWsFp66N8bf/6wRXe6dSnyuCdGzK\ni0kwckfdtuT/D/H3/36KX77ZRXWZjT/72A5aaufXIzK6XJiqqvMW7Ns9m7EYLYwm5yHfkE8pHic6\n0I+lqRnBOD/BaV1bJX/2sR1UuK387NBN/vt3j+JNc7Ar85CpzEQxkJ+KmMbZs2dpa2ujqUnWYB55\n5BFefvllVuboDH+7cGz4FHEpwd76ndwY8PGrw12phtEfeWAN79nRlNXmayqvwFhenpd9GeSQv3pH\nHVLEgLUy/58ikOzmtFBFxzs3N9BS6+Kbv7zASyf6OHVtjPt2NHHXlobUIjbFrGyr3s6JKyM8/UYn\nfaMBrBYjv//oBvZtyt4I2VqAfRnktPJXzsjOaz00VVuWeTAIAr/zrtU0e1z88JXrPPX8FY5cGOL+\nHc1sXVWN3WkGUaDcUEGdqY3nj/bw/NEepvxRmj1OHn94/YIJSNbWVjgsR6RYG9Vn/yrsbbiD586f\nBPKfB0mSiPT2YK7xYHQ4Ml7jtMlNsb//4lUOnh7gi/9+kl3rannPrhbaG9wYTQKJMGz2bGRiUuSn\np65w5PwQkViC3etr+eSDaxdMQLK2thI4fYq4dwo86uqzKNhMVnZ4tjB8TijIkR4dHJA7aC1QVmJF\nYxn/9VO7+Oenz/PW+SFOXh7hwLYmHtzdoqo2/e1MwYJ9eHiYhoYZDaauro5z584V+rVLxuHXbuK2\n1fLTX0SIhk4AsKa5nA/cvYK772hldHThxtHWllaC58+R8PtVV1RUEASBXVU76ZQgYgzm/W8I3OyS\nx5KjNkprnZu/+NQufn7oJgdP9/OTgzf4+aGbNDsEagFLqIIvfvs6kgQCcNeWBj50zwoqciSJWOrl\nAlD5OMwAKm0VtFrlscfN+dfnCPf2YLDbMdXUZL1GEATu2drI5hXVfO+FK5y6Nsa1Pi9mk4GV9ghu\nrIhDTfyXf5ZzGkxGgQ/e3cF797blbEg+43PpgT3ZSzFkY1vtZl6NyzZ2m4Yqn+nEp6ZITE9jX71m\nwetMRgOffGgd+zc38IOXrnHs8gjHLo9gtRhZb4hgilk5fzzB4SHZgVpdZuV337OG/Zvqc54srS2y\nYI/09sIq9WUdFPY27OTZ2GWwJPJ2pCvm0VzlqxXz1MkbE/zwxSu8eLyXF4/3Umkzsgq41jfIPopD\nSU2nYMGeb5ynR+NOvliU9zVisVThrK6mbUMZ79ndxuZVM4Ih1zgDa1cRPH8O2/QYFR2Zj+gLsT+4\nk05OMMl43nMyeLMTwWikactaDJbcmt7nP7qDx9+/mVeO9/DayT68kWvgr0OYqmJ9exWbV9Zw59ZG\nOhrV1X4BGGxvI9DVTXWFDYM5u1DK9m9cX76Wq/gYZgCPJ3Ps+EIkwmGuDg9TtnEDtbW50/o9Hjd/\n9bkauod8vHlmgDfODBCMD+CmkcRoLVtX13Dn1ib2b26gXGX2Y9yxnj5AGh7I+7f0mGqJAn7HOOs8\nC6+nTM+Y6L4KQNW61arG4PG42bOliWMXhzhxeYSzN4eJ+idwBMqxhl3csb6Sh/a2cceGeowqhaxh\n01omngHT+FDWcS5Edc0WXox3EbIFcFeYsZltuW+aw/SY/Oy6LesoU/H8h+vKeffuVl4+1suxi8Pc\nnOwkOi4QIXbbyCotFCzY6+vrGRiYyXAcHh6mtjZ3NblcmvBS4XY5KBMcfOKTM45TZWwejzvnOMUa\n+eUbOXeZWEO75uf7RuQ42QlxnDOd17KmbWdDEkUC3d2Y6xsY90YA9RrvvnW17F3r4YsHj8BwHfva\nV/LgYzPt77T8RoaGJqTrNxg4dzWrlrTQfNojZYCPs5PnGRrOXP99IUI3roMkYahv1DRuh1Hg3Tua\n2L3RzZd+fhTGG/n9d+1gzUY5SS0aijIaUn8cN9d4mL5xk5ERX14+E1vcSVgI8XLnm7Q5s5/Ass3l\n+DlZ449X1WmahxV1Lvl/66Z58RdhhEAl//UTu1MmoYnxzBFVmYiVy9FCE5ev0Yz2dz0WjSMkjMRM\nYV64eDjl79DC1JVrIAiEnFVEVDzf43EzNRlk56pqdq6q5t8vXeTwwDH+eNunbxtZBeo3yYKdp5s3\nb6anp4f+/n6i0SjPPvss73rXuwr92iXD6ZKTc/I9eaQch3nalxUbXswS4a2hzK3SFiI2MoIYDmsq\nS5pOr7+fgbhc7yYRzj/LrpAIIYBIQN7gvMIklyauar9/AYehGo4NnyJqliNlCnGYWVtaSUxPk/BO\n5b44A2JIQLLEOTt2nmBMe/io1m5BczkyeCxlDss3httUXYPBbi/4nYibI6mINS2k/Ax1dRhs2rX9\nSCLKyZEzVNrKWVe1OvcNtyEFC3aj0cif//mf8+lPf5pHH32URx55pGgcpyB73RMFZJjJHdqteduX\nlZfHZIWjQydJiNpC5RSBls1hmIu3Bk8gGuIYjIU5itK7SuVD+sv81tAJzfcXItglSeLI4HEkc2zW\nWPIhFcedx3qQJEmO5XdZiIlxToyc1vwdkd4ejC43psrsDS6yMRme4vLENdxuuTZOvvOQym0YHiYR\nztzjdyGUd6LM7eCGt5ORYPZEokzEx8cQQ6G834nTI+cIJyLsadiZtarn7Y4uo77nnnt4/vnneeGF\nF/jsZz+rx1cuGWo61C+EYDBgbW6RO7THtH+H8vKsbmhjOurn4sQVTfcXItBiYpzjQ6dwW1w4XbaC\nsuysTc0gCPlvcIEoJrOB2rIazo1eIBDT5kyO9PSA0Sg3QdFIl6+HocAwq+rlzamgeSigxEIkHEcU\nJWoqKhAQNGuriWSbQmtLa15moLcGTyAhsaJWbpBR8AYnSQS7ta8H5bntHvm31DoPhZ7elAYwe+vv\nyHHl7Utxbkc6okdYk7W1FUSRaL/6+t8KyrH/jhbZtq2kcatFrfc/E+fGLhKIB9lVvx2nq7CiRwab\nDXNdHZFe7bVzYKb29r6GO4hLCY4Pq9dWpUSCSF8v1sYmBJN2t5FSUXBfm1yeVw+NPZ/QT+W55W4H\nG6rX0u3rZcA/pPr+mYxT7WtBlETeGjyGxWBmbf0KoHCTFID/Zqfme5WNdVVdG3aTnbcHj2s6yabe\niTzMUWOhCa5O3ZAT5xboRXC7844X7Hp0S0lpaXmYIZTnrqxrodXdxIXxy0xFvKrvj/T2YPXUaA61\nhJkyxXc27sbutCBJEM6Qbq8WW2sbYihEbEzb0VkUJULBKA6XlV11OzAIBt5MK2Obi+jQEFIslteL\nHI6HOT5yhipbJRtqV2O1mQpaC6bKKowud14ae3oxOKWsw9z2gQtRiGC/PtWZKlNcWe6aNZ58UHwu\ngc4uzfcq819WZmdX3Ta80WlNJ9lCNPa3FW29QbvD9naiJNiz9PzUQiGOQ7n9lwmTycj+xt2pgmRq\niHu9JLxTODsy92ZdiLHQOJcnr7GyvJ16Z50uRY9mzBDa5iEciiFJ8m9RbnWzuWYD/f5BeqZzl1eV\nn9clP19D82qFkyNniSai7G24A4NgwOW2FiTYBUHA2tpKbHSURDB3Hfh00ovBba5Zj9Ps4O2hE6q1\n1ZlSAtrnQaluuq9hly6nWKV2TiAfjT2tAJgSEXNk4Jjq+yM9PRjLyzGVqw/XBbmD2uHBY1iNw/az\nyAAAIABJREFUFrZ7Nue+4TamJNh1qAlhaWzKu8FAeu3tO+q2YTaYOTxwdFYv0GwoJwRnR7vm5x5O\nvihK5T5dTi55OlDnNthQxvRG/1tZ70lH2VBteQi0wwPHEBBSdYJcZTbCwRgJFZ12sjErUUkDMxq7\nFZPBxO76HfhjAc6MXVB1f7inB8FiwVKvLWQ2FA9xauQcHns1qyo6sCeTowra4JK1c4Ld3Zpr56QL\n9hZXE02uBs6NX8IXzR12mPD7iU+M56WtX5y4wlTEy676HdhMhXVuutWUBLsOGrvBYsFS30Ckt1dT\ng4FU+6/kGOReq1sYC09wbTJ7aVWFcHdSsK9coWm8CTHBW4PHsJvsbE82ULiVGvvcssXrq1ZTZavk\n+PBpQvHcURWRnm4QhKy1crIxmFamuMomR5G43PILHQ7mb5KaqSFU2DwovQFe7zuS9R4FMRYjOjiA\ntblZdfNqhbcGTxATY+xv2I0gCBiNBmx287ym1lqxtrUhRqNEhwY13Rf0y450s8WIIAjsb1B/kk1F\nieVhlns9qUjcnaEnQ7FREuw6VXGztrbKDQZG1duXQ0nh4XDOZGoq2qrSwWchlKO3a4U2wX5+/BLe\n6DS767enyhTrobGbysowVlRoPrnMbTSS3gv0+PCphW6diVmu1R6zrPgY0rskKYK9kHlImeY0nlzm\ntoOrd9aypmIlV6duMBRYuLBYdKAfEgnNZhhJkni9/wgmwTgrEShTU2utKPMQ6dY+D+lF8XbXb8ds\nMPN6/5GcJ9l87eujgXEujl+ho6x1XpniYuQdL9hNJiMWa2EOM8jPgTpjgpg59q0ob6PeUcupkXN4\nIwsfPSM93Rhdbiw12rz3mRooOJNp84U2FrC1thGfnCQ+rb65daZGI/uUXqD9CztR42NjiMFgqtGF\nWsLxMEcGj1NucbOlZkPq704dBLu5tg7BastbY7enbfR3N8s1Z17PYZbK13F6ZfI6w8FRttduxW2Z\nccA7nBaikQRxFX1Ss2FtawcgnPSBqCEVy59WBM1hdrC7fjvj4UkujF9e8P5wlpLFuXj55htyb9em\n4tfWoSTYAXRpXGtTFnFXl+p7UppqmkATBIEDzXeSkBK80Z/9CJ4IBOSY5bY2TTHLI8HRlGbS5Jqp\nRTLjMCvw+J2HGSJTa8ByaxmbazbQ5x9I9QLNhCI0tEbEvDV0gnAizN1N+zAZZkIkXW7brDHlg2Aw\nYG1J5jZE1X9PwB9JOdIVttZspMzi5u2hE0QS2b8rX8fpoeQaO9A8u2iZLj6X5hbZ96RBY1cc6XPL\n9R5ovhOAg71vLnh/pLtbDr1VUdZEISEmePnmYewmOzuz9HYtNkqCHXkRh0M6Ocw0LOJsLfH2NOzE\nbrJzqP8IsURmW2+mLjlqeLVX1kzua7l71t9TDrMCN7h8en9mm4d7mmRh82rv61nvjeQRsyxKIq/1\nvYlJMHLXHA3NVVa4xg7JVnnJ3qNqCQWiqYQ5BaNBjpYKxcOcWCC2P9zTI/sZmptVP28yPMXZ0Qu0\nuBppL5ut4ephojRYrdibGjU1t86k7AA0uRpYVSF3VxoKjGS8VwyHiQ4NYm1t0+RnOD16Dm/Yx976\nnRl7uxYjJcGOPkX1jQ4H5to6wt1dquOvszWxthot3NW4B38skDVRJ9zdBYBNQ4ifPxbgyOBxqmyV\nbJvT29VoNGBzmAno4GsArSappAliTqnatZWraHY1cnLkLGOhiYz3ztRGUX/0vjRxjZHgGDvrts0y\nP0CaYC/UcagxQkh2pMczNpa4q3EPAgIH+97MuLYkUSTS24uloUFVdU+FN/rfQkLinub98059ejWa\ncK1ckWxunVkYz2WhXqeK1n4oy0k20tsjN5xJnp7VIEkSL/a8hoDAPc3aSy3frpQEO/o5UG1tbXJz\n67HMfRTnEsiiqQIcaN6PQTDwat8bGV/mGYHWrnp8b/S/TUyMcV/znRmrJzqdloJfZHONB4PDkYrY\nUUMwEMXuNGOYo2UJgsC7Wu9BQuKVLFp7uKcHU2UVJnfuUr0KB/veAODepKBIJ2WKKXiD09YPN7TA\nWqi0VbCzbiv9/kHOj1+a93lsZBgpEtZkhgnHw7ze/xYOkz3VJSkdvRpNOJOOfbV29oUE+9aajZRb\nynh78HjGaKl8lJ0rk9fpne5nT/N2ah2e3DcUCSXBjj72REhzFiUXWC5CSU3VmaHed6Wtgu2ezfT7\nB7k2Nb/VYKS7G4PdPqt59ULExDiv9b2JzWhjX2PmeucOl+wwixXgMBMEAVtbO7HhIRJBdfVeFmpi\nvbN2K5XWCo4MHMUfm53wIzevntKkrQ/4h7g4foUV5e20ls03WzidFgRBB5NUY5Pc3FqlSWohgQbw\nnrb7AHi+65V5G324S04CsmlIVDvUf4RAPMj9LXdnND/oEQYMssYO6k2UC82D0WDkQPN+wolIRlv7\njGBvVz2+F7pfBeD969+j+p5ioCTY0U+w2zQK9kAggsEgYLVlrm9yX8tdAPxmzssshsNEh4dkW6JK\nx+nx4dP4otNy+QBT5rBAvY7fyganRluNRePEogkcWZpZGA1G7mu5i6gY4/W+2ZEh+djXn+18AYD3\ntN2b8XPBIMz0Pi0AwWTC2tSsurl1LsHe5Gpgc80GOn098zZ6xWFva1Mn2COJKC/3HMJmtKXMG3PR\n6xSrJM+p3uCy2NgVDjTvx2ly8HLvoXlljSPd3QhWG+Y6dQla3b5erkxeZ23lKlZW5Vfm+HalJNjR\nJzkH0h2oXaquV7JOswnnjvI2NlSt5erkdS5PXEv9PdIrN69Wm4QRS8T4deeLmAQj97ZkfpFhZh4K\nFWqK5hjuzJ1OHgwosfzZbcPKZnSw7w3CaUfwlIamUmPv8fVxevQ87WWtbKpen/U6R4EF0RSsrW1y\nc+vB3MXhsvlb0nmw7X4Anu96ddbfw12dsuNU5Ty80f8W/liA+1ruxGG2Z7xGL43d5HTKvqcedb6n\nXPNgM9l4oPUAoXiIV5MmNQAxEiE6OICttVW14/TF7oPAzGloOVES7OinsRudTswejyoHaqZ43Uy8\nb+V7AXj6xq9TyRlhjbVRXus/zER4kgPNd6YyLDOhxNMXHPrZnhTs3SoE+5xyAhm/z2Tj/pa78ccC\nPN89I9RmTBDqErSeufk8AI+teHDBk47DaSURF4lG8jdJyePqmDXOhcgWGZROR3kraytXcXnyGlfH\n5MxkKZEg0tONpbEJgzV3Gnw0EePFnoNYjZZ5kVHpWG0mDEZBl2bO1tY2xECA+MR4zmuV9ZDJiaxw\nT/N+XGYnr/a+ntLatTpOu329nB49T6u7ibWVq1TdU0yUBDv6aewgmyHEQID4+MIO1Eg4jpiQcgr2\nFncjd9Rto9c/wMmRs/K93eodp/5YgOe6XsZhsvNQ+/0LXjtz/C4sIsRUVS1XOFQR069GoAE80HqA\nSmsFr/S+zlhoAkmSCHfexFhRgakid1OJ61OdXJy4wpqKlTm74ug1D6kNrjN3eYhcphiFhzveDcC/\nnvwhoiQSHRpEikZTz8rFq72vMx31c6D5TpxmR9brBEE2Sekh2BVnphqHeiAQxeYwY1ygcbjNZE1q\n7WFe6T2U/O6u5LPacz5DlER+dPVpJCQ+uOqRvGrX3+6UBDtgs5sRhMJty6Dezp7LlpjOYysexCgY\neebm88TFOOGebtXFnp7rfJlQPMx729+FY4EXGfQ7uQiCgLW9ndjYKAn/wr0y1ZggACxGCx9Y9TBx\nMc7Prz9LfHKShNerSlsXJZFfXP81AI+tfDDn9XqZIaxNzQgmkzqTlMr1sKqig931O7g52cOhviMz\np5b29pzPGA6O8uuul3CbXTzQeiDn9Ypg18MkBRBRc3LxR3HmWAsga+1ui4sXe15jKDCcMn+q0djf\nGjxOl6+HnbVbWbMMtXUoCXZgRjsp1LYMaY7DHNrJjKaa+/hcY6/m7qa9jIXGefbys0T7+7C1tee0\nJfb7BznUf4QaWxV3N+/P+Rw9Ty6KoMm5wanUVEGOkFlR3s7p0XN0npdjme0qBPvzXa/S6etmR+0W\nVpS357xeL1+DYDJhbWsn0tebMwM1FIgiCLKSkYsPrXoUp8XBMzefw3dD7g9rzeE4FSWRf7/0E+Ji\nnI+s/cCC2rqCw2VBTEhEwvm1jVRIKTs5NrhYNJF0pOdeC1ajhY+u+SBxMc5TF39EuKsLwWrNqewE\nY0GevvEbLEYLH1z1iOp/Q7FREuxJHAU2tVZIFYDKqbHnti2n89iKB6m113DhzKuy4zRH4a9ALMi3\nzn6XhJTgw2veh9mQu7OQXho7gK09Gb+cQ0tTa4oBeQP+8OrHEBA4d+ol+Tk5BHunt5tfd71IhbWc\nj679kJqhF9wuMR1be4ecgZqjMJrib1FjFnBbXHx8ywcJJyIMXz0jtwTMUdnyzYG3ueHtZKtnk+pa\n44rSESgwWcvocmGuqyfcdXPBDFTF9KVG2QHYVruZXXU76J/sITI4ILcEXEDZkSSJH1/7Jf5YgIfb\nH6DSVqHtH1JEFCTYn3vuOR599FHWr1/PhQvqakbfrjicFuJxkVi0MIeZ0eXCVFOT04G6UHJSJmwm\nG7+36XdpnJDHF2/OrpmIksi/XfgBY+EJHmq7n81pRa4WwmwxYjIb9NXYcwl2laYYhbayFt634iEq\nRmQTj9CcvRJfOB7m3y78AEmS+E8bPqpKS4UZwVKojR3SI4Sy29klSSLojy7oMJzL/Sv2s8rVimPE\nR6imbMGWgDemuvj59Wexm2z8b2s+oNqmrOcGZ1+xEjEUWrCEb0CDeVLhI2veR7vfgiBJRBqqFrz2\nmZvPc3ToJK3uplQo8XKlIMG+Zs0avv71r7NrV3G3kQL9Mu1A1lZFv3/BEr4zyUnqF3Gzu5GdIbmS\n43+E3s7YvT0hJvjZtV9xceIKG6rX8sgK9YkXejrMTBWVGMsrcjpQlSbWFqv6XqUPtNxDw6TERJmR\n73U9k7HD0Hhokq+c+iZj4Qne3XYvaypXqv5+XU8uyRPFQoI9GkkQj4ua1oJBMPCJqvswiXDdHeLZ\nzhczXnd54hpfP/0vxMQ4v7vutym3qs/Q1cskBaROmOGb2edB2UDU2NgVHGYHDxnWAfBi4irnxi5m\nvO5g75s83/0KHns1f7T192YVfluOFCTYV6xYQXt7e8Hmi9sBPe3L9lWyQyZ841rWawIabMsKkiTh\nGJwk6rJxnXH++7Gv8ubA28QSMSRJotPbzd8f/0de7XsDj72axzf8DgZB20+smKREsfDf1NbeTnxy\ngrh3Kus1akI+5xIfGcYUjROsr+T06Dn+7uiXOTt6QY6UiYc5N3aRvz/+VXqn+9nXsItHO7RlFerl\nPAW5hK/B4Vjw5JIyy6k0QSiYBuSNPVBXwW+6XuKpiz+k0ys3E58IT/JKzyG+ceZfEZH47OZPsq1W\nW7s3p1OfujkAthXyxhq+OT+LWkFLQEE65UNyiejBWgvfPPtdXuh+lfHQJJIk0Tc9wJMXvs9Prv0S\nt8XFn2z7zLz6QMuR5b1taSC1iHXQ0uwrZcEeun6dsn2ZE4JSha80CLX4xAQJr5eqHTt5fOPd/MeV\nn/H9yz/l+5d/ikEwpOLc72zczftXPpwzCiYTDqc11dRaq8Cdi629g8CZ04S7unBtnV+PRBQlQoEo\ndU3qtUiYMe9s2v4uhhoDHB44xjfPfReb0Uo4IQsho2Dkd9Z+iDsb92gOZzOaDHJTax0EuyAI2No7\nCF68QMLvz9h0PJDH6Q1InYYevPPjdE78hreHTvD20AncZhfTMdlUZTGY+YMtn8oZ4pkJXcOAm5oR\nzGbCndkFeyCPDU6SJELXr2EsL+f37v5j/vncv/H0jd/w9I3f4DQ5CMTlshaNznr+04aPUmPX1rug\nWMkp2B9//HHGMhS1euKJJ7j//oXjohfC43Hnfe9iUN8oCxdBmj22fMYpVmygz2Ih1n0z6/3RcByH\n00J9vfqGu2NXzwFQvXkDWzfdza6Ojfzowq+YCE4RSUSxGM18eOPDrPdof4kVqmuc3LwyitVsKvg3\nMm3byPjTP8cw1IvnATkZJv07/dMRJAkqq5yanjU9JJfCbb1jO19Ys5rf8j3ED889Q59vkFpnNR5H\nNfd27GNVdXte4/Z43JRV2Jn2hnVZp8GN6whevIB1apjKjoZ5nw/2eAGoayjT9LxY900MFgsb9uzm\nHw17OTt8iYOdR7gwcpXtDRvZ0bCZXU1bqXLk5yS0W+UInXhMLGgelHuHV6/Cd/kKVS4TRvv8jFcx\nLp8SW1orqax2qvruyOgoiakpqvbuYf2qjaxs/L95vfsoNya6uTnZTXtVM4+tfTfbGzbm3OBvN5lU\nCDkF+5NPPrkoDx4dzd2YdimJJ731I8PTqbF5PO68x2ltayd4/RpDPSMZF7HPG8JVZtP0/aOnzgOQ\nqGtO3mfmtzs+OG+chcytYJQXf3/fJEZLYUFTiepGEATGz5zH8eD0vHGODcv/bTQZNI158uIVMBoJ\nuqoJj05jxcUn1/zO7IvE/OZBGaPVZmJ0KMbgwBQm8/xKmFoQa5sAGD59gXjzfFv/0IAs2EVJUj3m\nSrtAsKcX+9p1jE/K2ZdNplY+vroV0vb1RABGA/mtB1GUEASYnAjkvabSf3NjcxtcvETf8XM41s0v\n6TAxLhd5C0diqp83ffQMAIaW9uQ9RvbX7GN/zewSvGNjC+dTFPKuLyVqNx/dwh2L3c7u1Cm0S8G2\nchUksyPnEo8liEYSmk0doZs3wGDQVL1OK3ral40OB9bmFsKdNxFj8xuGaAl1VJDicSK9PVhbWjGY\nc8d858tSOlDzsS37Ll8BScK+Kv/TWS4MSkG06cLnANLs7FnmIdVBSsNGGrpxHZgxf5aQKUiwv/TS\nSxw4cIAzZ87wuc99js985jN6jWvJ0VOgAakXLpxceOnkLdB6urE2NauqCZIvelX1U7CvXo0Ui2Us\njKY11BFk+7oUj2PX2MBbK3rOg6miAlNVNaEb1zPGcSvKRKbyzdnwXZCjP+yr1xQ8voXQqyAazETG\nhLI4UIP++R2kchG6cT2ZCLa8qjMWSkHO0wceeIAHHnhAr7HcUowmAza7WZfQLgDbSlk7CV2fHxmT\nj0CL9PUixWIprWex0H2DW72WqVdeJnTtKuzbMeuzlNPQrX4eglfkZsb2Net0GV829HQcAtjXrmX6\nyGGiA/1yL9A0gn456zS9iXUufJcugyBgX7nI68FlZXTITzQSx2or7IRkqqzCWFFB+OYNJEmaZfNO\nxEUi4Tg1deojVsRIhEhPN7aOFRjMy6OlnV6UMk/TcLosuoR2AZjcZZjr6uRFPEdLyycRQ9FycmWc\nFopTx9hlkDV2QBbsc8hHUw1dvSJ/75q1OowuO8qY9BLsjrXyRhRMjj+dgD+C3WGZ10EqG2Isiv/a\ndaytbRhsmcvu6oWe60EQBOwdK0l4vfMqPeZzig13dYIolswwGSgJ9jQcbqvcQShaWG0MBfuKVXK2\n3eDsbDslo1GTQFM01UW0qQLYHMkOQjpkXYKcqGT2eAhdn2+GCE5re5mleJzQtatYGpswlWkLkdSK\ncnIJ6DQP9qRgV35HBSXrVJNA60yao1Yv7lqAxTjBJTf6K7M3uFSoo1P9O6GYOW2LfGopRkqCPQ29\ntVVbMlEpNCdRSUvhK5CbFQcvX8JUVY25tk6XsWXDYBCwOy26vcgA9lVrEIMBgr19s/4e8EcwGAVV\nha9Arr8jRaPY1y6utg76m2LMNR5MlVWErlyZZa/OJ+s0nDTvLbZ9HcDp1i9JCcCxXi5vEbw0O0M0\nmEcsf8lxmp2SYE9D7+O3suDC1+YIdo2mmEhPD2IggGPDhiWpHe10yZUu9Yp0UgSQ7+Lslzngj+J0\nWVX/mxRtVzFrLCZ6a6qCIGBfu5aEf5rowExHpXyyToNXZbOWfdXiC/bUyUWnebA0NWN0uwlcujBr\nfWk1xUiiSOjGdUzV1arq8b/TKAn2NGZqY+ijnVgam+RFfPH87EWs0XmqaDeKtrPYOF3WlDNLD5Tj\nt+/ijBlCFCWC/ogmDW2pHKdAMuzOoFt0EMxsSKErl1J/05p1Koki4RvXsDU2YCpXn9yWL3qfXASD\nAce69SSmpoilFQTT+k5EeroR/f4leyeKjZJgTyNlitEpblcwGHBs3ETC6yXa15v6e9AvF74yW9TF\n6wYvyZUzHeuWSLAnj9+BaX02OHN9A0aXG9/FGYEWDkaRJPWaqhSPE7p+DUtD46Lb1xWcLqu+Jqk1\n8x2oWjX2aH8fYihE2frsPVv1xKljpUsFx/qNAATSzDFaywkEzstZ2M5N2urfvFMoCfY0UuVaddLY\nYWbhKQsRwO+PqDZBiLGoLNCampdEQwP9fQ2CIGBfs4bo2BjRoaFZ36021DHc3YUUiaSckEuBw2kh\nFNSnIBqAubYWU2XlLDu7Vo1dOb2VbVwawa6EYOql7EBmO7tWG3vg/DkQhNQmUWI2JcGeht4CDcCx\ncRMIQkqwJ+Ii4WAspRXnInzjBlI0uqRHTr01dgDnlq0A+M+ckr9bY6ijEua4FPZ1BYfLgiRBKKjn\nBreOxLQvFSml1d/iP30KBIHKnTtyX6wDBoMBu9Osq0nK7PHIkVKXLyEl5JLLWk6xiUCA8I3r2Fas\nxOhUV1PmnUZJsKeRqsmuo8ZucpdhbWsndP0aYjg0I9BUaqpLbV8H/TrnpOPcvFXe4M6clr97WqOm\nelk24yx2/Ho6etuXIS2e/bL8u2rZ4BJ+P6Hr17CtWImlYum6/zhdVgL+iK5lQxzrNyCGQqkG14FA\nRHUHqeClCyBJODdv0W08y42SYE/DaJS1Ez01dkiaYxIJgpcupR291WmqwUsXwGDAsQQhfgrKpqPn\nPJjKy3GvWU3o+jUSfr8mm2oiECB4+RLW1rYlM0dBWv0gHU8ujqRpzn/yhPzdGrJOA+fPgihmLIG8\nmDhcFuKxwruLzfrOpAkleOkCoigSCsRK9nUdKQn2OSyGdpJuZ1eEhBpTTCIYINzZKadML3KGYTqu\nRTDFAFTt3gWiSODc2Rmbqop58J8+CYkE7juWtlNXyiSl48nFXFWFbeUqQlcuE/f5CGrIOvWflk87\nzq3bdRuPGmYK5OnoSF6XPLlcukgoEEs+J/fpTZIkAufPYXS5sbaW6sNkoyTY5+BcBO3E1rECg8NB\n4MI5/Elh6VIj0E6evCVHTovVhNFk0NUkBVC1+w5AtrPPmCByv8z+48cAcO1cWsGu/EZ+nTc4985d\nIElMnzyhOutUiscJnj+L2ePB0pi9z+ti4FgkE6WtYwWhq1fwDcvlBdTMQ7S/j8TUFI6NmxZsXP1O\npzQzc1gM+7JgNOLYsJH42BjTQxOAOk3Vd+RNAMr27Mtxpb4IgiAnKekYCQFgb2nB7PEQPH+OgC+C\n2WLM2es0EQwQuHgBa0srlrrFzbqdy4wTWd95cN0hb3BTx0+qzjoNXrmMGA7j3Lp9SZLU0tGz92k6\n7n37QRQZPymH86oxTwbOlcwwaigJ9jnoHcuu4Eoen729Q7Oek43Y+BihK5exr1mL2ePRdSxqcLqt\nBANREon5ZWbzRRAEnFu3I4bDBLxBVRqa/9QpSCRwLbEZBtJ8DTpr7OaqamwrVjJ1U85tUGNbDiSj\niVzbltYMA/pnZCuU7doDRiMTV+X67K6yhedBkiR8bx0GoxHHpk26jmW5URLsc9C7NoaCa+cdGBxO\npsenEYTcx07fW0cAKNu7X9dxqEV5mUM6hrmBLJhEDISjkioNzX9CNsMstX0dwGQyYrObdDfFgLwe\nIkbZb5Jrk5dEEf/p0xgcjkUvApeJmdr0+s6D0e3GuXkLfp86v1P4+jWi/X24tu/E5F6aJLVipSTY\n57BYx06DxULZnXcRESzYzCzoLJMkCd+RNxFMpluiqcLiRMaAXJ0yXlkLgMO+cMxyIhggcOE81pYW\nLHX1uo5DLU63VXeNHeSNShHsuTT2wNkzxCfGcW3fiWBa+v7zi3WKBSjbt5+ISW66nsvvNHXwFQAq\n7r1P93EsN0qCfQ56t8hLp/yee4kYnViiC/dfjHR1EhsawrV9B0aHQ/dxqGExQv0ABJMJy54DAJgm\nBhe8dvrYMdkMs8RO03ScbiuxaIJoRJ+6OQrm6hrE2mYArGSfY0mSmPj1rwCofM9Duo5BLQ6XFUEA\n/3RY9+92btlGxCr38XQ4sod8xqd9+E8cx9LQuKTZx8VKSbDPYTGSUhSkihpEgxGzf4LIQH/W6xSn\nqXvfrTHDwOKE+ikIa2THl3TzEmI08zyLkQgTv3oawWymbP9duo9BLYsV+gkgJRtbx46+kfWa0LWr\nhG/ewLltO9amJt3HoAaDQcDhshLw6T8HBrOZmKMKczxE5NrlrNf53ngdKR6n/MB9S+48LkZKgn0O\n9mSjCb1NEEDKlmiNB/AefDXjNdGhQbyvH8JYXoFzw61zEC3m8Tuc/EqzfwLf4cxCbfKlF4hPTlL5\n7gcxV1XpPga1KCeXxbCzx9zVAMRPHSHS25PxmolfPwtA1Xsf0f35WnCVWQn49auboyBJEiEs2OIB\nxn72E6T4/JORJIp4XzuIYLFQtv/WKTvFREGC/R/+4R9473vfy/vf/34+//nP4/cvbGIoBpTO7Ho7\nT2FG+7WbJbxvHCLS2zvrc0kUGXryO0ixGLUf+/gtsacqLKbGrmyaNqJMPv+bVL0QhbjPx+RvnsXo\nclP50MO6P18Li1E3R8E/HUEQwBIPMfqTH837PNLbQ/D8Wexr1t7yZhIutxVRlHR3pkfCcRIJCWeZ\njUh3FxPP/XreNVMvv0hsbBT37r0YHaXaMGooSLDfddddPPvsszz99NO0tbXxzW9+U69x3VIcLquu\njSYUFOHg2b0NKRql/2tfJj41lfp88sXnCd+4jnvXbjmJ5RaSciIvgkBTNovqbZuIjY4y+sMfzGqb\nN/7M04jhMFXve/8t8zEoLKZgD/giuMpsODdsIHjh/KwKoHHvFEPffRK49do6LF6yljJpmH0fAAAa\nGklEQVSvVWtXYKyoYPyZp4mklbj2nznN6I/+A2N5OdXve7+uz17OFCTY9+/fn4ru2LZtG0PJkqzF\njtNlIREXCQVjun6vsohrNq+j5kMfJj4xQf/XvkLg/Dkmnv8N47/4GUa3m9qPfULX5+aDEuq3GCYp\nZR4aH3svloZGpl55iYGvfYXg1Sv0f+0reF99GXNdHRX33Kv7s7WSEmg6z0MiIRLwR3G5rdR8+CMg\nCAz809cY/fF/ELx0kZ6//SsiXZ2U7b8zVV/mVuJMxpj7dbazKxuFu8pJ3Sc/BYkEg//yTbyHXmP6\nxDEGv/UNBLOZpj/5Auaqal2fvZzR7az/k5/8hEceufWahR64ymwA+KZCGC36uSHSC4BVvPcRosPD\n+N58nf6vfEm+QBCo/cSnMLrduj2zEBwuK36f/pEQQX8Um92EraaKlv/z/2Hwm/9E4NxZAufOAnIr\nvdqPfeKWmqIUUmGfOgs0xTnvKrNia22j/vd+n7Gf/pjJ559j8vnnAKj+4G9R9fCjt4WzcLGcyOm1\nk1ybtlF+zwG8h15j+KknU9c0/OGfYOtYoetzlzs535zHH3+csbGxeX9/4oknuP/++wH4xje+gdls\n5rHHHlP9YI/n9hBemahvLOP8yX68UyHWbtQvfjoWkW3JbR3VWG1map74Y3qb5DR5Z1srrlUrsdXn\n97zFmM/KagcTowHKy+w5U//V4vG4CQailFfak2N2U/fXf0H3975P4GYnTR98P+Vbt9xSYZY+l5Ik\nYbYYiYRius5xKOmU9tSV4fG48Tz2IB0P3sfwiy8x+tobNH3wfVTv26t6nItNJCg7NRNxUfNzF7pe\nTMjmzqaWSjweNzX/+fP4H32IYE8Pwd4+XCtX4jlwd/4D12mcxUbOt/XJJ59c8POf//znvPbaazz1\n1FOaHjw6Oq3p+qVEMMpCxTcZ0nWck+MBzBYjvukwJGOCHe95FAAJmAam83iex+NelPlUmh50dY5T\nWV24rdvjcTPQP0UkHMdqM80as/PhD+AEYsDY2K1zwmeaS4fLwtSUvmuht2cSAKNZmPW9pt1307D7\nbkQWfkcW6zfPRizp4B4dntb03FzjHB2SP4snEjPXVTVgqGrAtW2PfM0S/DuXej7zRe3mU5Cd4dCh\nQ3z729/mG9/4BhaL+qbEtzvKsdM7FdL1ewMamzffapyL0CpwOmnaUcxdxYDLbSUcjJGI61c3J6Ch\nyuftgMNpwWAQdDfF+DWUsS6hnoLO13/zN39DLBbj05/+NABbt27lL//yL/UY1y1FKUbkndRPsMfj\nCcKhONW1Lt2+c7FZjIgQxWbvLi8ewZ6ejVxWoU9dfH9qgysOgSYnKVkWJSrGZjdhNqtr7F5CHQUJ\n9hdeeEGvcdxWKCnUemrsWhpL3C4ojkM9X+Zpb1JTLRKBBmkRIdN6CnZlHopng3O5rQwP+BBFCYNB\nHx+IPKfFMwfFQinzNAMGg4DLbcWno8ZejEdOd1Lo6Bnipphi3MUk0Bahbo5/OoLJbMBqu/WRP2px\nlVmRJHRrbB2NxIlFE0VjjiomSoI9C64yG9O+sG71yFM2VZV9HW8HFG1y2qtfyGNRmmIWySTlcltv\ni1BGtSjzoFcIbDEqO8VCSbBnwVWe1E50SkyZidctHuep1WbCYjWltGw9mPZGVNWjv53Q2yQVi8n+\nlmIywwC43PJ49drgis2BXEyUBHsWUtqqTkJN0XqLSVMFKCu3Me0N61Zewe8L43RbMRqLZ+nNJOfo\nu8kXm0Bz6Zx9qrbBRgntFM/btcS4dV7ExSrYXeVW4jGRcKjw8gpiQiQwHSk6TdWuc6hfSqAVkQMZ\n0kwxemvsRTYPxUBJsGfBlXIc6qOx+7xhLFYjVlv2ZgK3I8pGpMcG5/OGkaSZTbNYUJp769Vowl+E\nDmSYEcC6bXAlG/uiURLsWdDz2ClJEtPeMGXl+oTKLSXuVN2cwoWaEj7qKrJTC8gbXGA6qkuS0kyo\nY3EJNLtDPrnodYpN+Z2KKKCgWCgJ9iwojiI9NPZwKEY8JhadGQbSNXYdBHsyfLTYNHYgFb+uh8/F\nX6Q2doNB55PLdASL1ahbHaISM5QEexasNhNWm4lpHbSTYrWvw8yY9Qh59E4GgeJKylFwV+h3cim2\nrNN0nGU2gv4ooljYyUWSJDnkswjXQjFQEuwLUF5h10VTTQn2Isyw01ewh2Z9ZzFRVj5TyrlQ/NMR\nrDYTZkvxaaoutz5hwJFwnGgkUco6XSRKgn0ByirtRCMJIuHCOtQrWl4xCjRZABl1MUEUsynGrZhi\nCtzgZE01UnRmGAW9fE/KWtCrREOJ2ZQE+wKUJxddoTZFRRiUFaFgFwQBV5lVN429WDXVMp1MMak0\n+iLc3CDNmV7gelBOPiWNfXEoCfYFKK9MCvYCtZNitrGDvCEVenKRJAnvVKho58DhtGA0GZj2FmaK\nmYlhL855KEu+E4XWUVI2yJLGvjiUBPsCpDT2As0QPm84lZ5fjLh0iIwJh2JFrakKgoC73Fawxp4K\ndSxSU0x5pbwWCq18OqOxlwT7YlAS7AugaCeFRMYoMezFqqmCPsdvRaAVW1JOOmXlNiLheEEnF8W2\nrJwGiw1XmQ1B0FFjL+L34namJNgXQA+NPZTsvFPMtsRULHsBgl0xRxVzeJvyGxZijil2wW40GnCX\n23TR2F1lVoymkghaDEqzugDu8qR2UsDxu9jt66BPyGOqDnt5cZogANzJzOFC1oMSy1+sgh3ksYcC\nMaKR/E4uibiI3xcpaeuLSEmwL4DRaKCswo53In/tRLElLgvBXsDJxZ/snFTM86BHZIx3MoTdaS5a\nfwvM2MXznQdlHZXs64tHSbDnoKLKTjgUy7u64XLQ2O0OczIiJH9fQzE2sZ7LzMklv40+kRCZ9oYp\nr3ToOawlRzlt5NsTuBTquPgUpDZ89atf5eWXX8ZgMFBdXc0Xv/hFPB6PXmO7LSivcsCNCbyTIWx2\n7ZUZZ2LYi1c7EQQBd4Gx7FMTQSxWE3ZHcVW3TCelqeY5D9PJ6pbFbIaBdI09T8E+mXwninwebmcK\n0tg/85nP8Mtf/pJf/OIX3HvvvXz961/Xa1y3DRVV8uKbGg/mdf+Mxl68tmWQtVUlZFEroijhnQxR\nU+sqqlZwc0nVD8rTBKGY9IpesCshjwVr7MU9D7czBQl2p9OZ+u9QKITBsPwsOxVV8rF5ajI/we7z\nhrHZzUWZbZmOklKfz8s87Q0jJiSqa525L77NcZfbknXltXeUUtaQoiwUKwVr7KnkpJIpZrEoWNp8\n+ctf5umnn8btdvPUU0/pMabbivKkYM/HgSpJEn5vmOpal97DWnIqq+V5mBwPUFOn7d+jnHZqlsE8\nlFXYGBv2EwxENdcR9y2T+ihmsxGny5J3LLtvKoTZYszLtFlCHTkF++OPP87Y2Ni8vz/xxBPcf//9\nPPHEEzzxxBN861vf4nvf+x6f//znVT3Y43FrH+0toL2jGrPFiN8X0TxmnzdEIiFRU+ta9H/vYn9/\n+4oa3uQ60VBC87OuXxgBWJJ50IOFxljXUM7NK2MYMWj+twT9sgN+5eparLbCT3C3ci6ra130dE5Q\nWenAZDIueG36OCVJwucNU1XjpLa2bLGHqYliWJtqybm6nnzySVVf9Oijj/IHf/AHqgX76Oi0qutu\nJR6Pm7ExP+UVdsZH/YyM+DTZiHs7JwCwuyyL+u/1eNyLPp8Gs/zv7uuZ1Pysvp5JAKprF3+chZJr\nLk0W2dzY0zWOzaVN4xwdnsbhtOCbDkGB07AUv/lCOJwWkODm9bHUaS4Tc8cZDESJRRM4Fvmd0Mqt\nnk+1qN18CjKKd3d3p/775ZdfZsWKFYV83W1LeZWdeEzU3OtxYjQAQLWn+G3LTpcFs8XI5HhA871T\n40EEAapqijvMD/KPZU8kRPy+cNE7ThXyLQZWcpwuDQWdB7/0pS/R2dmJwWCgsbGR//bf/pte47qt\nSDlQJ0Ka4rAnxmQhWFlT/IJdEAQqaxyMDfkRRVGTo3xyIoi73JbzyF4MKGtB6wbnm1oeoY4KqVh2\njQ7UkuN0aShIsP/jP/6jXuO4rSlXQh4ngjS3V6q+b2IsgMEgLJuXubLaycjANN7J8ILH73TCoRjh\nYIy6huVhv3SX2zBbjIyPaBPsqVICRR4Ro1Be0thva5ZffOIiUJFHZIwkSUyOBamodmA0Lo9pTkXG\njKkXalMTyRA/lRvB7Y4gCFTXOpmaCBKPq4/pXy4x7AqKxq1VY1fWTrGHfN7uLA+Js8ikkpQ0xLL7\nfRFi0QRVy8AMo1BZo5gh1M+DEuqobI7LgSqPC0mCyTH186AIwOUi2K02M1abSXNew9iwH4vVVNQl\nNoqBkmBXgdVmxuYwa9LYFcfpcnAYKlRWy5uUFvuysgksF40dZpzhym+shuWmsYN8gvNNhojH1J1c\nYtEEUxMhauqKOwO5GCgJdpVUVNnxTYVIJERV1yuO06plEBGj4C63YTQZNGmqisau1iZfDCgJZ+Oj\nftX3eCdDOFyWos9ATsdT70aSYFzlBqfM13JIVLvdKQl2lVRUOpAk9WFuyykiRsFgEKiosjM1HlSd\nUj81EcRqMy2rLEPFvKbWgRoJx5n2qnc4Fws19bJDfHRIXfz32LAs2Ks1Zi6X0E5JsKskPTJGDROj\nAYwmw7Lz/lfWOInHRVWVHhMJEd9UmIpqx7I6elttJtxlVtWmGEXw1TbcXpmWheKplwW0WsE+PlLS\n2JeKkmBXiWJSGVOxiEVRYmo8SGW1A4Nh+Qg0SK8Zk3uD802FEEWJymXkOFWoqnURDEQJBaM5rx0Z\n9AFQu0xCPhUqqx2YTAZNGrvBIKSc8CUWj5JgV0ldo6xtDfX7cl477Q0Rj4vLKiJGIeVAVWFnV65Z\nTo5TBcWBqsYcMzwgrxllDS0XDAYD1XUuJsdyh36Kosj4aICqGueyCf+9nSnNsErsDgsVVXaGB3yI\n4sL25YlRWaAtJ8epwkzIo3qBphzZlxNqHaiSJDEyMI3TbcHpLu6a/Jnw1LkRRSnnBuedCJGIiyX7\n+hJREuwaqG8qJxZN5EzQmXGcLj9NtbzSjsEgpOylCzHQO4XBIFDXWL4EI1taUiGPOQRaYDpCMBBd\ndvZ1BbV29jHFvl4S7EtCSbBroK5ZMcd4F7wuFeq4DE0xRqOB2sYyxob9RMLZ+8DGonHGhvx46t2Y\nLcVfI2Yu5VV2jEYhZ6jfyKDiOF1e9nUFj8rIGCUipuQ4XRpKgl0D9U2y5jnUl93OLkkSw33eZZ1d\n19xWgSTBQM9U1muG+mWTVUPL8tPWQbYvV9Y4mRgLLGiaW672dYXKGtmBOja08AkuFepYEuxLQkmw\na6Cy2oHFalpQY58cDzLti9C6onJZhfiloxRC6+uazHrNYK88R40tFUsypltBtcdJIi4u6G9QNHZF\ns11uGAwGqmtdTIwFsjpQJUlibMSPu9ymS4ORErkpCXYNCIJAfVMZvqkwwUDmMLeeG+MAtKyoXsqh\nLSm1jWWYzAb6urNr7AO98mf1zctTYwdobJM3uO7r4xk/F0WJ0aFpKmtkhWC54ql3IYpS1rj+oD9K\nOBgr2deXkJJg10h9k3ykHs4S9thzU+6a1LqiasnGtNQYjQYaWyuYGg/iz9B8JB5PMDLgo6bOtaw1\ntPZV1QgCdF6d3zoS5HIKsWhi2TpOFXLZ2ZV3Yrmao25HSoJdI3WKnT2DOSYaiTPY68VT75Jbhy1j\nmpPaan8Gc8zIwDSJxPK1ryvY7GYaWysYGZzOuMEp9vXl6jhV8CT/fdlMc9cuDgOwan3tko3pnU5J\nsGukrtGNIGROVOrrmkQUJVqXsRlGoSkp2Pu657/Mg0kzzHK2ryt0rKkBoCuD1t6f7PW63DXVqhon\nVR4nXdfG55kofd4Q/d1T1DeXL9tggtuRkmDXiNliorrWxeigj3BodrhfygyzcvmaYRSqa53YHGb6\nuybnFQQbSDpOl7vGDtCxWhbsN6+Ozvq7fzrCjUujVFTZl71tWRAENmxrQBQlrpwbmvXZhdMDAKze\nUNLWlxJdBPt3vvMd1q1bx9RUdmfacmLNxjoSCYmTR2aaeUuSRM/NcWx207K3qYL8Mje3VRDwR2cV\nRgsGogz1e6msdmB3LG9zFICrzEZtg5uBnqlZG/25432IosTWPS3LNjoqnTUb6zCaDFw6Mzhroz9/\nsh+DQWDlOs8tHN07j4IF+9DQEIcPH6axsVGP8RQFm3Y04S6zcu5Ef6rK4cRogMB0lJaOqmVX+Csb\nTcmwx7PH+1N/e/2Fa8RjIhu3v3PWQ8eaGiRpJjomEo5z8fQADqeFNRvrbvHolgarzcyqdR68k7Lp\nBeTQ38E+Ly0dle+ITf52omDB/nd/93f82Z/9mR5jKRqMJgO77ulATEgcfb0T72SIF56+CEB78mj+\nTmD1+lqqPE4unhrg3Ik+blwe4eaVUeqby9m0s+lWD2/J6Fgja6PnT8kb/cUzA0QjCTbf0YTJtPyy\nbrOxYZu8mV86M5A0ywwCsGrDO2Nzu50oKBbtlVdeoaGhgbVr1+o1nqJhzca6/7+9u4tpMkvjAP6v\ntIDDOKaK06DD6CwOG4gFRhPdgURtbeSjVlFRboymDUZvrCB+hKJGA8aAqJekxAjRZDTK2myI0Wym\nWiEIIsYFN6Q6bHAcjAVRMhSj9OvZC9dO2NJqzOgp5fndnSYn+acfT09P3/c56Or4DY/+PYBfe19g\n7I0H6Uu/mVI/OWXRUuQVKPH3c/fQ+nMvZNFSREmnQZX31ymx/fCOfPYXSPxOjt/6hvGT+Q6ipNMg\ni46aUr9aAEAx7yvI47/Af+zP8fiXFng8Psiio/Dd95F/MUG4eW9h1+v1GBoK/Me/uLgYZrMZZ8+e\n9T/2oafqRAKJRIK/rfwLrl56ALfLixU5yf4Vy1QyY2Yscjcq8Y+f/gXXmAc/qpIi6uDqD5W3KQ29\nPQPobP0Vvw+/RvrSRMTERs6pUR9CIpFg8Y/z0fLPX/DVzFjMmhOHH5Z+G1HHAU4WEvrIavzo0SPo\n9XrExsa+7Y8yMACFQoHLly9j9mz+hmaMMVE+urD/P7VaDYvFgpkzI/8SN8YYC2d/2nXsEolkSm3F\nMMZYuPrTVuyMMcbCA995yhhjEYYLO2OMRRgu7IwxFmGEFXa73Y7CwkLk5+ejoKAADx48EBXlvc6f\nP4+cnBzodDrU1NSIjhNUuPfsqa6uRm5uLtatW4ddu3ZhdPT9B2J/Ts3NzcjJyUF2djbq6upEx5mQ\nw+HA1q1bkZeXB51Oh3PnzomOFJTP58P69euxc+dO0VGCcjqdMBqNyM3NhVarRVdXl+hIE2poaMCa\nNWug0+lQWloKl2vig378SBCDwUAtLS1ERGSz2WjLli2iooTU3t5Oer2e3G43ERG9ePFCcKKJPXv2\njAwGA6lUKhoeHhYdZ0Ktra3k9XqJiOjEiRNUU1MjONEfvF4vaTQa6u/vJ5fLRWvXrqXe3l7RsQIM\nDg5ST08PERGNjo7S6tWrwzInEVF9fT2VlpbSjh07REcJ6sCBA9TY2EhERG63m5xOp+BEgRwOB6nV\nahobGyMiot27d5PFYgk5R9iKXSKRwOl8e+KK0+mEQhGe/SQuXLiA7du3Qyp9e/fcrFnh2ZJ3MvTs\nyczMxLRpb99yGRkZcDgc75nx+XR3d2P+/PmYN28eZDIZtFotrFar6FgB5syZg5SUFABAXFwckpKS\nMDg4KDhVIIfDgVu3bmHTpk2iowQ1OjqKzs5ObNy4EQAglUrx5Zfh2WLZ5/Ph9evX8Hg8ePPmDb7+\nOnQbZGH3+paVlaGoqAhVVVUgIly8eFFUlJAeP36Mzs5OnD59GjExMdi/fz+USqXoWONMxp49jY2N\n0Gq1omP4DQwMICEhwT9WKBRhvT0IAP39/bDb7UhLSxMdJcC7hca7xVs46u/vh1wuR1lZGex2OxYt\nWoTy8nLExobXgSAKhQJ6vR4rV67E9OnTkZWVhczMzJBzPmlhD9ZnpqSkBLdv30Z5eTk0Gg2uX78O\nk8mE+vr6TxknqFD9cLxeL0ZGRnDp0iV0d3ejuLhYyEpusvTsCfWaq9VqAEBtbS1kMhl0Ot3njheU\nyOfsY7x69QpGoxEmkwlxcXGi44xjs9kQHx+PlJQU3LlzR3ScoDweD3p6enD48GEolUocO3YMdXV1\nMBqNoqONMzIyAqvVips3b2LGjBkwGo1oamoK/fn55BtEQSxZsmTcePHixYKShFZUVEQdHR3+sUaj\noZcvXwpMNN7Dhw8pMzOT1Go1qVQqSk1NJZVKRUNDQ6KjTejKlStUWFjo3y8MF/fv3yeDweAfm81m\nMpvNAhMF53a7yWAwUENDg+goEzp58iStWLGC1Go1ZWVlUUZGBu3bt090rADPnz8ntVrtH9+9ezcs\n/w+4du0alZeX+8cWi4WOHj0aco6wPXaFQoGOjg4AQFtbGxYsWCAqSkgajQZtbW0AgL6+Png8Hsjl\ncsGp/pCcnIzW1lZYrVbcuHEDCoUCFoslLBuxNTc348yZM6itrUV0dHgdvKBUKvHkyRM8ffoULpcL\nV69exapVq0THmpDJZMLChQuxbds20VEmtGfPHthsNlitVpw6dQrLli1DdXW16FgB4uPjkZCQgL6+\nPgBAe3s7kpKSBKcKNHfuXHR1dWFsbAxE9EE5he2xV1RUoLKyEj6fDzExMaioqBAVJaQNGzbAZDJB\np9NBJpOhqqpKdKSQwrlnT2VlJdxuNwwGAwAgPT0dR44cERvqf6KionDo0CEYDAYQEQoKCsLyQ37v\n3j00NTUhOTkZ+fn5kEgkKCkpwfLly0VHm5QOHjyIvXv3wuPxIDExEcePHxcdKUBaWhqys7ORn58P\nqVSK1NRUbN68OeQc7hXDGGMRhu88ZYyxCMOFnTHGIgwXdsYYizBc2BljLMJwYWeMsQjDhZ0xxiIM\nF3bGGIswXNgZYyzC/Be68EGj7hfMcwAAAABJRU5ErkJggg==\n", "text/plain": [ - "[]" + "\u003cmatplotlib.figure.Figure at 0x7f385e198650\u003e" ] }, - "execution_count": 48, "metadata": { "tags": [] }, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "# Create TensorFlow Variables using Keras's Dense layer.\n", + "def f(x):\n", + " return tf.square(tf.sin(x))\n", "\n", - "wb = tf.keras.layers.Dense(units=1, use_bias=True)\n", + "def grad(f):\n", + " return lambda x: tfe.gradients_function(f)(x)[0]\n", "\n", - "# We can access the underlying TensorFlow variables using wb.variables.\n", - "# However, the variables won't exist until the dimensions of the input\n", - "# tensors are known. Once the dimensions of the input tensors are known,\n", - "# Keras can create and initialize the variables. Until then, Keras will\n", - "# report the variables as an empty list: [].\n", + "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n", "\n", - "wb.variables" + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(x, f(x), label=\"f\")\n", + "plt.plot(x, grad(f)(x), label=\"first derivative\")\n", + "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n", + "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n", + "plt.legend()\n", + "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "docKLUaonYG_" + "id": "-39gouo7mtgu" }, "source": [ - "## Step 3: *Define the loss function*\n", + "## Gradient tapes\n", "\n", - "Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)." + "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n", + "\n", + "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n", + "\n" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -245,125 +182,42 @@ } }, "colab_type": "code", - "id": "0_w8ZJSCtuY7" + "id": "MH0UfjympWf7" }, "outputs": [], "source": [ - "def loss_fn(predictions, labels):\n", - " \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n", - " return tf.reduce_mean(tf.square(predictions - labels))" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "cellView": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 34 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 348, - "status": "ok", - "timestamp": 1525154234538, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "RkNbXoXkpjVH", - "outputId": "e4688f3c-e29f-416d-f541-6d81953b5660" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\u003ctf.Tensor: id=1252, shape=(), dtype=float32, numpy=16.979801\u003e" - ] - }, - "execution_count": 50, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "# Test loss function (optional).\n", + "def f(x, y):\n", + " output = 1\n", + " for i in range(y):\n", + " output = tf.multiply(output, x)\n", + " return output\n", "\n", - "loss_fn(wb(inputs), labels)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 418, - "status": "ok", - "timestamp": 1525154260083, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "K_7beXoHOU7t", - "outputId": "8f55c028-fe2b-4edb-ad68-a849afc60623" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "w: -0.311619\n", - "b: 0.000000\n" - ] - } - ], - "source": [ - "# At this point, the variables exist, and can now be queried:\n", + "def g(x, y):\n", + " # Return the gradient of `f` with respect to it's first parameter\n", + " return tfe.gradients_function(f)(x, y)[0]\n", "\n", - "w, b = wb.variables\n", - "print(\"w: %f\" % w.numpy())\n", - "print(\"b: %f\" % b.numpy())" + "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n", + "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n", + "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n", + "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "JVDWpL9VYWdP" + "id": "aNmR5-jhpX2t" }, "source": [ - "## Step 4: Create an optimizer\n", + "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n", "\n", - "We'll use a `GradientDescentOptimizer` to fit our model." + "For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -371,36 +225,48 @@ } }, "colab_type": "code", - "id": "DudNEebMKDWN" + "id": "bAFeIE8EuVIq" }, "outputs": [], "source": [ - "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)" + "x = tf.ones((2, 2))\n", + " \n", + "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n", + "# a single t.gradient() call when the bug is resolved.\n", + "with tf.GradientTape(persistent=True) as t:\n", + " # TODO(ashankar): Explain with \"watch\" argument better?\n", + " t.watch(x)\n", + " y = tf.reduce_sum(x)\n", + " z = tf.multiply(y, y)\n", + "\n", + "# Use the same tape to compute the derivative of z with respect to the\n", + "# intermediate value y.\n", + "dz_dy = t.gradient(z, y)\n", + "assert dz_dy.numpy() == 8.0\n", + "\n", + "# Derivative of z with respect to the original input tensor x\n", + "dz_dx = t.gradient(z, x)\n", + "for i in [0, 1]:\n", + " for j in [0, 1]:\n", + " assert dz_dx[i][j].numpy() == 8.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "YBeJYxY8YaiO" + "id": "DK05KXrAAld3" }, "source": [ - "### Step 5: Define a training step\n", - "\n", - "To fit model variables to the data we'll need to:\n", + "### Higher-order gradients\n", "\n", - "1. Calculate the gradients of the loss with respect to the model variables.\n", - "2. Use `optimizer` to compute updates to the variable values based on those gradients.\n", - "\n", - "To calculate the gradients, we use the [`tf.GradientTape`](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context manager\n", - "and its `gradient` function to compute gradients through computation conducted within its context:\n" + "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { - "cellView": "code", "colab": { "autoexec": { "startup": false, @@ -408,163 +274,37 @@ } }, "colab_type": "code", - "id": "diDZfrMJM3OC" + "id": "cPQgthZ7ugRJ" }, "outputs": [], "source": [ - "def run_step(inputs, labels):\n", - " with tf.GradientTape() as g:\n", - " loss = loss_fn(wb(inputs), labels)\n", - " # Compute the partial derivatives of loss with respect to the variables\n", - " grads = g.gradient(loss, wb.variables)\n", - " optimizer.apply_gradients(zip(grads, wb.variables))\n", - " return loss" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "1WWepgmJQOzc" - }, - "source": [ - "Repeatedly running the training step will nudge the variables towards the values that best fit the data (i.e., \"w\" will move closer to 3.0, while \"b\" will tend to 2.0):\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 51 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 380, - "status": "ok", - "timestamp": 1525154412590, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "ya5Qxz5XQlhU", - "outputId": "8dd47155-a6c1-44c5-c279-617c803f1723" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Values of w, b BEFORE applying gradients: 2.725763, 1.894334\n", - "Values of w, b AFTER applying gradients: 2.774932, 1.922555\n" - ] - } - ], - "source": [ - "w, b = wb.variables\n", - "print(\"Values of w, b BEFORE applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n", - "run_step(inputs, labels)\n", - "print(\"Values of w, b AFTER applying gradients: %f, %f\" % (w.numpy(), b.numpy()))\n" + "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n", + "\n", + "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n", + "\n", + "with tf.GradientTape() as t:\n", + " with tf.GradientTape() as t2:\n", + " t2.watch(x)\n", + " y = x * x * x\n", + " # Compute the gradient inside the 't' context manager\n", + " # which means the gradient computation is differentiable as well.\n", + " dy_dx = t2.gradient(y, x)\n", + "d2y_dx2 = t.gradient(dy_dx, x)\n", + "\n", + "assert dy_dx.numpy() == 3.0\n", + "assert d2y_dx2.numpy() == 6.0" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", - "id": "61TgeLVlKEQp" - }, - "source": [ - "## Step 6: Create a training loop\n", - "\n", - "Of course, now we can simply turn all of this code into a self-standing training loop. We'll also capture our loss and approximations of `w` and `b` and plot them over time." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - }, - "base_uri": "https://localhost:8080/", - "height": 364 - }, - "colab_type": "code", - "executionInfo": { - "elapsed": 580, - "status": "ok", - "timestamp": 1525154278709, - "user": { - "displayName": "", - "photoUrl": "", - "userId": "" - }, - "user_tz": 420 - }, - "id": "VukGe-huNaJ4", - "outputId": "c79c8e63-c781-451e-f74f-20815d8da49f" + "id": "4U1KKzUpNl58" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.9409681558609009, 1.3733772039413452, 1.7128530740737915, 1.9793939590454102, 2.188689708709717, 2.3530514240264893, 2.4821391105651855, 2.583533763885498, 2.6631851196289062, 2.7257626056671143]\n" - ] - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd8AAAFKCAYAAABcq1WoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xd4U2X/BvD7ZLRpumlLS6EDgbKh\niIggU7aAgPhDRKsIUoYgiK++ioAguBARXmZBEARFUBGhiChIEQcqe+/RMlpGd9KRcX5/nDZtaFra\nkuY07f25rlw5zXmSfPMk5OY5Oec8giiKIoiIiMhhFHIXQEREVN0wfImIiByM4UtERORgDF8iIiIH\nY/gSERE5GMOXiIjIwVSOeJJbtzLs/pi+vlqkpOjt/rhkjf3sGOxnx2A/Owb7WRIQ4FnsOqcd+apU\nSrlLqBbYz47BfnYM9rNjsJ/vzWnDl4iIyFkxfImIiByM4UtERORgDF8iIiIHY/gSERE5GMOXiIjI\nwRi+REREDsbwJSIih/vxx61YtGi+3GXIhuFLRETkYA45vSQREZEtGzeux65dPwMAOnbsjOeeG45/\n/tmHFSuWwNVVA1/fGnjnndk4eHB/kdtUKueNMKesPDZ2C7p37wSNxkfuUoiInN6MGVOxdetmuz2e\nQiGgb98BmDFjdontbty4hgMH/sGKFV8AAKKjX0DXrt3x3XcbMH78q2jZshX27PkVaWmpNm/z8/O3\nW82O5nSbnTMy0jFixHN48cUX5S6FiIjuw9mzZ9G0aXOoVCqoVCo0b94S58+fRdeu3fHxxx/giy9W\noUGDhvDz87d5mzNzupGvp6cXOnTohF27duHkyRNo0qSp3CURETm1GTNm33OUWhYBAZ6lms1OEABR\nFC1/GwwGCIICvXv3Rdu27fDbb3H4739fxezZc2zeFhYWbreaHc3pRr4AEB09DgCwYsVSmSshIqLy\niohoiOPHj8FoNMJoNOLkyROIiGiI1as/g1KpwoABT6Jbt564fPmizducmdONfAGgR49eqFevHr79\ndgPefnsG/P2de/MDEVF1FBQUjFatHsKECdEwm0X07z8AQUG1EBgYhEmTxsHT0wuenp4YOvQ56PX6\nIrc5M0EsPOavIKXZ/FBW69d/jokTJ+LNN6di8uQ37P74JCnt5iO6P+xnx2A/Owb7WRIQ4FnsOqfc\n7AwAL774Ijw9vbBq1Qrk5ubKXQ4REVGpOW34enp6YtiwKNy8mYQfftgkdzlERESl5rThCwAvvTQa\nCoUCMTFL4ICt50RERHbh1OEbFhaO3r374ujRw/j7731yl0NERFQqTh2+ADB6tHTY0fLlS2SuhIiI\nqHScPnwfeaQ9mjdviR9/3Ir4+Ctyl0NERHRPTh++giAgOnoszGYzVq5cLnc5REQkk/Pnz1kGYe+8\n8xZycrLL/ViHDx9ESkqyvUorwunDFwAGDhyMgICa+PLLL5CZmSl3OUREJIM9e35FQkI8AGDmzA/g\n6qop92Nt27alQsPXKc9wdTdXV1e8+OJLmDPnfWzY8BVGjoyWuyQiIrqHYcMGY+3ajRBFEX36PIaF\nC5ehUaMmmDx5PN54420EBdWCyWTCnDnv4fr1azAajXjppTFo3boNtm+PxaZNG6FSqVG/fgQGDhyM\nH37YhD17foWvry+mT38LX3yxAZ9+Oge+vr44c+Y0UlNT8OyzL2Dbtq1IS0vFokXLIQjAzJlTkZWV\nhezsbLz66uvQ6TKxd28cLl26iNmz5+DMmZP4+ut1UCpVaNiwMSZMePW+X3uVCF8AeOGFkZg/fy5W\nrFiKF198CQpFlRjUExFVOPcZU+FqxykFoRDg3ncAdPeYrKFhw8a4ePECjEYDGjVqjOPHjyIiohGS\nk5MRFFQLAPDLLz/Bz88fb701HampqZg4cQzWrPkaX3+9DnPmzEdgYBC2bduCOnXqoG3bdujSpRua\nNGlm9TxKpQoLFizFzJlTcezYUSxYsASzZk3DwYP7ER5eF/36DUSnTl1w4MC/+PLLNXjvvY9Rv34E\nJk9+A15eXlizZiWWLfscLi4umDbtTRw9ehgtWkTeVxdVmfANCAjA4MFDsH79Ouza9TN69Ogtd0lE\nRFSCyMgHceLEMeTm5uCpp57Gnj270bLleURENLS0OX78KI4cOYSjRw8DAHJycmAwGNC9ey9MmfI6\nevXqg+7de5W4iblxY2n2Oz8/f8tMSL6+ftDpMlGjhh/WrPkM69evhcFggEZj/TiXLl1EUlIiJk8e\nDwDQ6TKRmJiIFi3u77VXmfAFgFGjxmL9+nWIiVnK8CUiKiXdjNn3HKWWRUCAJ3SlOLdzq1atsW7d\nauTkZKNfvwHYtm0rjh07ggcffMjSRqVS4/nnRxT5To+KehE9evRBXNxOvPLKWCxeXPwOt0ql0uay\nKIrYuPEr+PvXxLRps3D69EksWjTf6r5qtbSped68Rfd8PWVRpbbNNmvWHB06dMJvv+3GqVMn5S6H\niIhKEBoahqSkJGRm6qDVusPPzw9798ZZhW+TJs3w++97AAApKcmIiVkMs9mMmJjF8Pf3x9Chz6FZ\ns+ZITEyEIAgwmUxlqiEtLRW1a9cBAOzZsxtGoxEAoFAoYDKZEBoajsuXL1l2vlq5Mga3bt2879de\nqvA9e/YsunfvjnXr1gEAbty4gaioKAwbNgwTJ06sVBMbcK5fIiLn4evri6CgIABS0N64cQM1awZa\n1j/2WHe4uWkxZswIvPHGq2jRIhIKhQJarTtGj34REyeOhSAIaNAgAi1btsL8+R9j//5/Sv38vXv3\nxYYNX+LVV19G06bNcOfOHWzbtgWRkQ9i6tT/4vr1a5g48TX85z8TMXbsCKSlpcLfP+C+X/c9pxTU\n6/UYPXo0wsPD0bBhQzz33HN466230KlTJ/Tp0wfz5s1DUFAQhg0bVuxjVMTUUsVNWWUymdCu3YO4\nceM6Dh06xbl+7xOnBnMM9rNjsJ8dg/0sua8pBV1cXLBixQrUrFnTctvff/+Nbt26AQC6du2Kv/76\nyw5l2odSqcSoUWOQk5ODtWs/l7scIiKiIu4ZviqVqsjeX1lZWXBxcQEA+Pn54datWxVTXTk988xz\nnOuXiIgqrfve27k0U/n5+mqhUinv2a6sihvSBwR44qWXRuLTTz9FXNxPePbZZ+3+3NVJSZtOyH7Y\nz47BfnYM9nPJyhW+Wq0W2dnZ0Gg0SEpKstokbUtKir5cxZXkXr8pDBv2IhYsWIC5cz9Bjx79IQiC\n3WuoDvjbjWOwnx2D/ewY7GfJff3ma0v79u2xY8cOAMDPP/+Mjh07lq+yChQWFo5evR7H4cOH8M8/\nf8tdDhERkcU9w/f48eOIiorC999/jy+++AJRUVEYP348Nm/ejGHDhiE1NRUDBw50RK1lxrl+iYio\nMrrnoUb24MhDjQoTRRHdunXEyZPH8e+/RxESEmr3Oqo6bj5yDPazY7CfHcPe/RwXtwtdunSz2+M5\nit03OzsLzvVLROTcbty4jp07d8hdht1V6fAFgEGDnoK/fwDWrVvDuX6JiCqRYcMGw2QywWg0okeP\nTjh9Wjot8OTJ45GYeAMAMG/eRzh8+CA+/3wFVq6MwaxZ0zFu3EvYv/8fTJ36huWx+vaVRsaXLl3E\nK6+MwcSJY/HWW68hI6Nybumo8uGbP9dvenoaNmz4Su5yiIgqpRqtm9m8aAptNfQcN8pmG8/o4ZY2\nmrWrgfDwUj1n/pSC586dsUwpaDabraYUfOaZKERGPogXXxwFADAaDViy5LNip42dP/9jvP76FCxY\nsBRt2jyCTZs2lqs/KlqVD19AmutXOlPXUpjNZrnLISIiFEwpeOzYETz11NM4efIELlywnlLwbvnT\nAxbn5MkT+Oij2Rg/Pho7dvxomRChsqlSUwoWp2bNmnjyyf/D119/ybl+iYhsSD5w/J5tMpasuGeb\n7Kjh8Jw8AbDTlIJ3U6vVAFDk3A35sxFpNBosXBhT6c/tUC1GvoA01y8AxMRwtiMiosqgNFMK5k/t\ndzd3d3fcuXMbAHD+/Dno9dLJnOrXb4B9+/4EAOzcuaNMMxw5UrUJ3+bNW+DRRztyrl8iokrkXlMK\nhoXVxZkzp/G//31idb/69SOg0bhhzJgR2LHjRwQFBQMAJk78D9au/Rzjx0fjxx9jS9yELacqfZzv\n3bZv34YXXngGzz33AubNW2j3mqoiHhfpGOxnx2A/Owb7WVJtj/O9W8+evREWFo5vvvkat2/flrsc\nIiKqpqpV+HKuXyIiqgyqVfgC0ly/Hh6enOuXiIhkU+3C19PTC88+G4WkpERs2fK93OUQEVE1VO3C\nFwBGjhwNQRCwfPkSOGB/MyIiIivVMnzDw+uid+++OHz4EP79t3IeA0ZERFVXtQxfgHP9EhHJ6ccf\nt2LRovl2eSydLhP//LMPALB27WocP3603I+VmJiIkyfvfbav+1Vtw7ddu0fRrFkLxMb+gISEeLnL\nISKicjpz5rQlfKOihqNZsxblfqyDB//FqVMn7FVasarFuZ1tyZ/r95VXxmLVqhV4551ZcpdERFSt\n3LhxDf/5zyu4eTMJQ4YMQ79+A6zWf/fdRuzc+RMEQYGOHbvgmWeew9mzp/HJJx9BrVbDxcUFM2d+\ngHnz5kCv1yEkJBTHjx9Fly7dkJaWisOHDyI1NRWXLl1EdPRY7Ny5A5cvX8L06bPRtGkzLFw4DydP\nnkBubi4GDhyMDh06Y9Wq5VCpVAgMDELt2iH49NM5EAQBWq0WU6bMgKdn8SfOKItqG76ANNfvu+9O\nx7p1a/Daa/+Fh4eH3CURETncjBmu2LrVfnGgUAB9+7pixoycEtslJMRj1aovodNlYvjwYejb9wnL\nhAjXr19DXNwuLFmyEgAwduxIdO3aHT/+uBWDBj2F3r374sCBf5GcfAfDhkXh4sULGDDgSatNzgkJ\n8Viy5DNs3boZ69atxqpVX2L79q3YuXMH6tdvgKCgYEyYMBk5OdkYMmQg+vcfiD59+sHHxwcdOnTG\nxIlj8frrUxASEopNm77Bpk0b8cILI+3SR9U6fPPn+v344w+wceN6jBgxSu6SiIiqjRYtIqFSqeDt\n7QN3d3ekpaXBx8cHAHDq1AlcvZqACRNGAwD0eh0SE6+jQ4fOmDv3QyQkxKNbtx4ICwvHiRPHbD5+\no0ZNIAgC/Pz8Ua9eAyiVSvj6+kGnOwJXV1ekp6dhzJgRUKlUSE1NKXL//OkJAcBgMKBx4yZ2e+3V\nOnwBaa7fBQs+wYoVSzF8+MhiJ2gmIqqqZszIuecotSykczuX5vGsp/0rPAugSqVGu3aP4o033i5y\nr88++wJ//rkXs2fPwPjxk4p9dKVSaXNZFEUcOnQABw/ux6JF0mbmHj06Frl/RU5PWO2TJn+u3wsX\nzuPXX3+RuxwiomrjxImjMJlMSElJQVZWFry8vC3rGjZsjIMHDyA7OxuiKGL+/LnIycnGd99tQHp6\nGnr27IOnnx6Gs2dPQxAEm9MOliQtLRU1awZCpVLh99/3wGQyw2AwWE1hWJHTE1b7kS8gzfX79ddf\nIiZmCbp37yV3OURE1UJoaDimTXsT164lIDp6nNUIMygoCEOGPIOXXx4FhUKBTp26wNVVg9q1QzBt\n2pvw8PCAWq3GlCnvIDU1BcuWLURAQM1SP/dDD7XFl1+uwfjx0ejYsTPat++AuXM/QPfuPTF79gz4\n+Phi4sT/YM6c9/Dll2vg4uKKGTNm2+21V6spBUsyaFBf/PHHXvz2299o1Kix3R7X2XFqMMdgPzsG\n+9kx2M8STilYCtHR0kk3VqxYKnMlRERU1TF88xSe6/fOnTtyl0NERFUYwzdP/ly/2dnZnOuXiIgq\nFMO3EM71S0REjsDwLSR/rt/ExBvYunWz3OUQEVEVxfC9S/5cvzExiznXLxERVQiG71041y8RUcUr\nzZSCu3fvdFA1jsfwtYFz/RIRyW/dujVyl1BhGL42cK5fIqKKlz+l4PPPP43Y2B+s1n311Rc4f/4s\npkx5HQcP7scbb0zC+PHROH36FPr27WZpN3XqGzh4cD/0eh2mTn0DEyeOxfjx0Th//pyjX06ZMHxt\nyJ/r12w2Y9WqFXKXQ0RU4Vq3drd5WblSbWkzbpzGZpvoaI2lzdq1aoSHl+45ExLi8eGH87BwYQxW\nroyx2s9m2LDn4eHhgfff/xgAcOHCecybt6jYMxBu3Lgebdu2x4IFS/Haa29i0aJPy94JDsTwLcbA\ngYPh7x+AdevWIDMzU+5yiIiqHFtTChanfv0GcHFxKXb9sWNHsXnzdxg/PhqffPIhdLrK/b3NiRWK\nodFoMHz4SMyd+yHn+iWiKu/AAd092yxZkn3PNlFRBkyerMGtW6V51uKnFLybWq22ebvRaMxbr8Kr\nr76OZs1alOaJZceRbwleeGEkXFxcsGLFUpjNZrnLISKqUkqaUhAAzGbbh3sKgoDs7GxkZ2fj7Nkz\nAIAmTZrht9/iAACXLl3E11+vq9Da7xfDtwSBgYEYNOgpzvVLRFQB8qcUnDRpbJEpBQEgIqIhRo16\nvsj9Bg58CtHRL+D992eiYUPpN+Cnnnoa164lYNy4l/DRR7MRGfmgQ15DeXFKwXs4duwIunXriM6d\nu+Kbb3649x2qGE4N5hjsZ8dgPzsG+1nCKQXvQ/PmLdG+fQfs2bMbp0+fkrscIiKqAhi+pcC5fomI\nyJ4YvqXQq1cfhIZyrl8iIrIPhm8pSHP9jkZ2djbWrVstdzlEROTkGL6lNGxYFDw8PLFy5XIYDAa5\nyyEiIifG8C0lT08vDBv2HOf6JSKi+8bwLQPO9UtERPbA8C2DunUfQK9ej+PQoYPYv59z/RIRUfmU\nK3x1Oh3Gjx+PqKgoDB06FHv37rV3XZVWwVy/POyIiIjKp1zh+/3336Nu3bpYu3YtFixYgPfee8/e\ndVVa7dt3QNOmzREb+wOuXk2QuxwiInJC5QpfX19fpKamAgDS09Ph6+tr16IqM0EQMHr0OJhMJs71\nS0RE5VLuczuPHDkS8fHxSE9PR0xMDCIjI4ttazSaoFIpy11kZZOdnY2wsDDk5ubi6tWrcHd3l7sk\nIiJyIuWaz/eHH35AcHAwVq5cidOnT2PKlCnYtGlTse1TUvTlLrA4cp+4+/nnR2Du3A+xePFyvPji\nS7LVUdHk7ufqgv3sGOxnx2A/S+w+scLBgwfRoUMHAECjRo1w8+ZNmEym8lXnpDjXLxERlVe5wjcs\nLAxHjhwBAFy7dg3u7u5QKqvOZuXSyJ/r9/z5c9i9e6fc5RARkRMpV/g+/fTTuHbtGp577jm89tpr\nmDFjhp3Lcg7R0WMBADExS2SuhIiInEm5fvN1d3fHggUL7F2L08mf6zcu7lecPn0KjRo1lrskIiJy\nAjzD1X0qmOt3mcyVEBGRs2D43qeCuX7XIzmZc/0SEdG9MXzvU+G5fteuXS13OURE5AQYvnbAuX6J\niKgsGL52wLl+iYioLBi+dsK5fomIqLQYvnbCuX6JiKi0GL52xLl+iYioNBi+dsS5fomIqDQYvnbE\nuX6JiKg0GL52NnDgYPj7B2Dt2tXQ6XRyl0NERJUQw9fONBoNhg8fibS0VGzcuF7ucoiIqBJi+FYA\nzvVLREQlYfhWgMDAQAwcOBjnz59DXNwuucshIqJKhuFbQTjXLxERFYfhW0FatIhEu3aPYvfuXThz\n5rTc5RARUSXC8K1AnOuXiIhsYfhWoN69H0doaBjn+iUiIisM3wqkVCrx0kujkZWVhXXr1shdDhER\nVRIM3wo2bFgU3N09ONcvERFZMHwrmJeXN4YNew43blxHbOwPcpdDRESVAMPXAV56aQwEQcAHH8xC\nZmam3OUQEZHMGL4OULfuA3j55Ym4fPkSpk79r9zlEBGRzBi+DvLmm1PRokUkvvpqLbZu3Sx3OURE\nJCOGr4O4uLhg2bKVcHNzw+TJr+Datatyl0RERDJh+DpQ/foNMGvWh0hLS8X48aNhMpnkLomIiGTA\n8HWwqKjh6NOnH/74Yy8WL/6f3OUQEZEMGL4OJggC5s1biMDAIHz44SwcPnxQ7pKIiMjBGL4y8PPz\nw6JFMTAajRgzZiR0Op3cJRERkQMxfGXSuXNXjBv3Ci5evIBp096UuxwiInIghq+M3nprGpo1a4F1\n69Zg61ae/YqIqLpg+MrI1dXVcvjRa69NwPXr1+QuiYiIHIDhK7OIiIaYOfN9pKamYsKEMTCbzXKX\nREREFYzhWwm88MII9O79OPbu3YMlSxbKXQ4REVUwhm8lIB1+tAg1awbigw/exdGjh+UuiYiIKhDD\nt5Lw9/fHwoXLYDAYePgREVEVx/CtRLp27YbRo1/G+fPnMH36FLnLISKiCsLwrWSmTp2Bpk2bY+3a\nz/Hjj7Fyl0NERBWA4VvJ5B9+pNFoMHnyeCQm3pC7JCIisjOGbyXUsGEjzJjxHpKTk/Hyy6N5+BER\nURXD8K2kXnzxJfTs2Rt798Zh2bLFcpdDRER2xPCtpARBwKefLkZAQE28994MHDt2RO6SiIjIThi+\nlVhAQAAWLlxqOfxIr9fLXRIREdkBw7eSe+yxHoiOHotz587inXfelrscIiKyA4avE5g6dSYaN26K\nNWtW4qeffpS7HCIiuk/lDt8tW7bgiSeewJNPPom4uDg7lkR302g0WLZsJVxdXfHqqy8jKSlR7pKI\niOg+lCt8U1JSsHjxYnz11VdYtmwZdu3aZe+66C6NGzfBjBmzcefOHc5+RETk5MoVvn/99RfatWsH\nDw8P1KxZE7NmzbJ3XWTDiBHR6N69J+LifsXy5UvkLoeIiMqpXOF79epVZGdnY8yYMRg2bBj++usv\ne9dFNgiCgAULlsLfPwCzZ8/AsWNH5S6JiIjKQRBFUSzrnZYvX46DBw9i0aJFuH79Op5//nns3r0b\ngiDYbG80mqBSKe+7WJJs374djz/+OBo3boz9+/dDq9XKXRIREZWBqjx38vPzQ6tWraBSqRAaGgp3\nd3ckJyfDz8/PZvuUFPsfnxoQ4IlbtzLs/rjO4KGHOuCll0bjs89iMH78RHz00bwKe67q3M+OxH52\nDPazY7CfJQEBnsWuK9dm5w4dOmDfvn0wm81ISUmBXq+Hr69vuQuksps+fRYaN26Czz//DD//vF3u\ncoiIqAzKFb6BgYHo1asXhgwZglGjRmHq1KlQKHjIsCNpNBosXSodfjRx4jgkJSXJXRIREZVSuX7z\nLauK2PzAzRqSFSuW4u23/4uuXbth/frv7P6fIPazY7CfHYP97BjsZ4ndNztT5fHSS2Pw2GPdsXv3\nLnz22TK5yyEiolJg+Dq5gsOP/PHuu9Nx4sRxuUsiIqJ7YPhWAYGBgZg/fzFyc3MxduxIZGVlyV0S\nERGVgOFbRfTs2QcjRozC6dOn8O670+Quh4iISsDwrULeeWc2GjZshJUrl2Pnzh1yl0NERMVg+FYh\nbm5uWLZsFVxcXPDKK+Nw8+ZNuUsiIiIbGL5VTNOmzTBt2kzcvn0LEyeOhQOOJCMiojJi+FZBo0aN\nRZcuj2HXrl+wcmWM3OUQEdFdGL5VkEKhwMKFy+Dn54eZM6fh1KmTcpdERESFMHyrqMDAIHz66WLk\n5ORgzJgRyM7OlrskIiLKw/Ctwnr3fhzDh4/EqVMnMWvWdLnLISKiPAzfKm7GjPcQEdEQK1Ysw65d\nP8tdDhERgeFb5Wm1WixdutJy+NGtW7fkLomIqNpj+FYDzZu3wNtvz8CtWzcxadI4Hn5ERCQzhm81\nMXr0OHTu3BW//LIDq1atkLscIqJqjeFbTeQfflSjRg3MnDkVp0+fkrskIqJqi+FbjQQF1cKnny5G\ndnY2xowZycOPiIhkwvCtZvr06Yvnnx+BkyeP4733ZspdDhFRtcTwrYZmznwP9es3QEzMYuzevUvu\ncoiIqh2GbzXk7u6OZctWQq1WY8KEMbh9+7bcJRERVSsM32qqRYtIvPXWdNy8mYRXX32Zhx8RETkQ\nw7caGzduAjp27IIdO7Zj9eqVcpdDRFRtMHyrMYVCgUWLlsHX1xfvvDMFZ8+ekbskIqJqgeFbzdWq\nFYx58xYhOzsbo0ePQE5OjtwlERFVeQxfQt++/REVNRwnThzD+++/K3c5RERVHsOXAADvvvsB6tWr\nj6VLFyIu7le5yyEiqtIYvgSg4PAjlUqFCRPG4M6dO3KXRERUZTF8yaJly1Z4881pSEpKxKuvjufh\nR0REFYThS1bGj5+IDh064aeftuGLLz6XuxwioiqJ4UtWpMOPYuDj44Pp09/C6dOn5S6JiKjKYfhS\nEcHBtfHJJwuRlZWF/v374+LFC3KXRERUpTB8yab+/Qdg8uTXcf78eTz+eDf8/fc+uUsiIqoyGL5U\nrDffnIbly5cjLS0Ngwf3w/fffyt3SUREVQLDl0o0atQorF//HVxdNRg9egQ+/fRj7gVNRHSfGL50\nT126PIbY2J9Rp04IPvhgFiZNehm5ublyl0VE5LQYvlQqjRs3wfbtuxAZ2Qrr16/DM88MRlpaqtxl\nERE5JYYvlVpgYBC+//5H9O7dF3v37kHfvj1w5cplucsiInI6DF8qE3d3d3z++TqMGTMeZ8+eQZ8+\nj+HAgX/lLouIyKkwfKnMlEol3n33fXz44SdITk7GoEF9sXXrD3KXRUTkNBi+VG4jRozCunUboFSq\nMHJkFBYtWsA9oYmISoHhS/ele/de2LLlJ9SqFYx3352G//xnEgwGg9xlERFVagxfum/Nm7fATz/9\nimbNWmDt2s/x7LP/h/T0NLnLIiKqtBi+ZBe1agVjy5af0KNHL8TF/Yr+/Xvh6tUEucsiIqqUGL5k\nNx4eHlizZj1GjozGqVMn0bv3Yzhy5JDcZRERVToMX7IrlUqFDz6Yi9mzP8StWzcxYEAfbN++Te6y\niIgqFYYvVYjo6HFYvforAMDw4cMQE7OYe0ITEeW5r/DNzs5G9+7dsWnTJnvVQ1VInz598cMP2xEQ\nUBPTpr2FKVNeh9FolLssIiLZ3Vf4Ll26FN7e3vaqhaqgli1b4aeffkXjxk2xcuVyvPDCM8jMzJS7\nLCIiWZU7fC9cuIDz58+jS5cudiyHqqI6dUIQG7sDXbo8hl9+2YEnnuiNGzeuy10WEZFsBLGcP8RF\nR0dj2rRp2Lx5M2rXro0nn3xblRT0AAAgAElEQVSy2LZGowkqlbLcRVLVYDAYMH78eCxfvhy1a9dG\nbGwsIiMj5S6LiMjhVOW50+bNmxEZGYmQkJBStU9J0ZfnaUoUEOCJW7cy7P64ZM3e/Txr1seoVSsU\nM2dOxaOPdsBnn61G9+697Pb4zoqfZ8dgPzsG+1kSEOBZ7LpyhW9cXBwSEhIQFxeHxMREuLi4ICgo\nCO3bty93kVQ9CIKAl19+BaGhYXj55VF47rmn8f77H2PEiFFyl0ZE5DDlCt/58+dblhcuXIjatWsz\neKlM+vcfgODgYERFDcWbb76GS5cuYsaM2VAq+fMEEVV9PM6XZNO6dRts374LERENEROzGC+++Bx0\nOp3cZRERVbj7Dt8JEyaUuLMVUUnCwsKxbdsv6NixM376aRsGDnwcSUmJcpdFRFShOPIl2Xl7+2D9\n+u/wzDPP4ciRQ+jTpxtOnTopd1lERBWG4UuVgouLC+bPX4wpU6bj6tUE9OvXE7t375K7LCKiCsHw\npUpDEARMmvQfxMSsQm5uDoYNewpr166WuywiIrtj+FKlM2jQU/j2263w9vbGa6+9glmz3oHZbJa7\nLCIiu2H4UqXUtu0j+PHHXahXrz4WLvwUo0YNR1ZWltxlERHZBcOXKq0HHqiHH3/ciXbtHsXWrZvx\n5JN9cevWLbnLIiK6bwxfqtR8fWtg48bNeOqpp3HgwH706dMNZ8+ekbssIqL7wvClSs/V1RWLFy/H\n66+/hfj4y+jbtwd+//03ucsiIio3hi85BUEQ8Prrb2HRohjo9ToMGTIQX3/9pdxlERGVC8OXnMqQ\nIc/gm29+gIeHB155ZSw+/HAWyjkrJhGRbBi+5HTat++AH3/chbCwcMyb9zHGjh2J7OxsucsiIio1\nhi85pfr1G2D79l/Rpk1bbNr0Lf7v/wbgzp07cpdFRFQqDF9yWv7+/vjuu60YOPBJ/P33X3j88W64\nePG83GUREd0Tw5ecmkajwbJlqzBp0n9w6dJF9OnTjXtCE1Glx/Alp6dQKDBlynTMn78YGRkZePLJ\nfnjhhWE4ffqU3KUREdnE8KUqY9iwKGzZ8hPatGmL7dtj0aVLO0yYMAbx8VfkLo2IyArDl6qUhx56\nGLGxP2Pdug1o1KgJNmz4Cu3aPYgpU17HzZs35S6PiAgAw5eqIEEQ0LNnH/z66+9YuvQzBAfXxmef\nxeDhh1vigw/eRVpaqtwlElE1x/ClKkuhUGDw4CH4888DmDPnU3h6euLTT+eiTZsWWLhwPvR6vdwl\nElE1xfClKk+tVmP48JH4++/DmDp1JgBg1qzpaNs2EqtXr4TBYJC5QiKqbgTRAefmu3Urw+6PGdCm\nOUzmoqXrx72C7JHRAADPcaOg/vuvIm0MrR9CxvLVAADN2tXQzp9r8zmS/zoIuLhAee4svIc+abNN\nxryFMHTuCgDw6dUFitu3i7TJHvIM9P99GwDg/s7bcI39oUgbU2gY0r7fBgBw2b4NHlP/a/P5Urfu\ngDm4NoTUFPh262izjW7KdOQMHgIA8Hr2/6CysddvbtfuyJw7HwDgtnA+3FZ/VqSNqNVCdfoUbt3K\ngGr/P/AaPcLm86WvWgtjy1YAAN+2kRCMxiJtsqLHImv0ywAAj0kvw2XvniJtjM1bIn21dL5m16+/\nhPvHH9h8vuQ9+wAPDyguX4LP4P4222TOmYfcbj0BAD79ekJx47plndlsRkZ6OlZl6fG60Yjw8Lr4\nrmEjtDxxHBAEq8cx1wpGauzPAACXXT/D443JNp8v9butMIfXBTIzUaPzIzbb6F5/CzlDnwUAeA1/\nFqpjRyzrlAoBJrOI3I6dkTl/MQDALWYx3JYvLfI4okqFlL8PAwBURw7Ba0SUzedLj1kF40MPAwB8\nOz4MwcZIP2v4S8iaMAkA4PGfSXDZvbNIG2Ojxkj/8hsAgOt3G+H+/rs2ny9l116IPr5QXL8Gn/69\nbLbJnP0Rcvv0BQB4D+oLpY2d4XL6DYBu5nsAAO1H70GzcX2RNmZ/f6TuiAMAqPfshufkCTafL+3r\nTTA1iAByc1Gj3YOWfi5MP+k/yI4aDgDwjB4O9YH9RR7H0LYdMpasAABoVi6Hdsn/bD5f8oHjAADl\nyRPwjnraZpuMRTEwtHsUAODb9VEI6WlF2mQ/+zz0k98AALhPeR2uO7YXaWOqVx9pGzcDAFy2bobH\njKk2ny9l+68Qa9aEcPMmfPs8ZrNN5ozZyO0/EADgPWQglBeKHi+f06sPdO9/DADQzpsDzZdfFGkj\nenkjZfcfCAjwROqWn+A5frTN50tbuwGmJk0BADVaN7PZRs7vcnsJCPAsdp3Krs9E5AQUCgW8fXzw\nwpBncBoivvjic+y4fAmBajW8vX3g5uYmd4lEVMU578g3wLNCHpesVYd+vnLlMj7++AN8883XEEUR\nDz/8CKZOnYFHHmnvsBqqQz9XBuxnx2A/S0oa+fI3X6r2wsLCsWhRDPbs2Yc+ffrhn3/24YknemPo\n0CdxrNCmYSIie2H4EuVp1Kgx1qz5Ctu370KHDp3w66870a1bR0RHD+c5o4nIrhi+RHdp3boNvvtu\nKzZu3IzIyFbYvHkTHn20DV577RVcv35N7vKIqApg+BLZIAgCunR5DDt2xGHlyrV44IF6WLt2Ndq2\njcQ777yN5GROX0hE5cfwJSqBIAjo338A9uzZhwULliAgoCaWLl2Ihx5qgblzP0RmJncqIXJqZjOg\n0wGZmQ59Wu7tTCViP1vLycnBmjUrMX/+XNy+fRv+/v6YOPE1vPDCSGg0mnI/LvvZMdjPjnHf/SyK\ngMEAITsLQlYWoNdDyM6GkKWHkJUFITsL0GdJfxe6HdlZEPRZBW2yCrXRF2pT+PbsbOkpFQqkffUt\nDI91t1MvlLy3M8OXSsR+ti0zMwMxMUuwZMlCZGSko3btOnj99bcwZMgzUKnKfvg8+9kx2M92IoqA\nTgdBp4Ogy4Sg00Ghy4SgywR0OngrTMhISraEoJCVBdwVggXhWNBG0OuB/DA1mexbsloN0U0L0c0N\n0GggarUQNRrLbaKPL3RvvwNznRC7PSfDl8qN/Vyy5OQ7+N//PsWqVcuRnZ2NBg0i8Oab09Cv3xMQ\n7jpbVknYz45RLftZFKWQKxSUQmZmwbLuruXM/GUdhMyMguXC6/Q6CHaIDlEQADc3KfzcCsIQbm4Q\nNW4QtXnrNG557Qq3KRScGqkd7g5UjRugzbsux3+K7xfDl8qN/Vw6169fwyeffISvvloLk8mEyMhW\nmDLlHXTu3LVUIcx+dgyn6Oe8sFSkp0FIS4OQnlZ8GBYXmlZ/Z0Iwm8tfjiBAdPeA6O4uXTw8Cy17\nWK9zl9Z5BvkhzaQoCNG84LQEZn6AuroWOaVrVcLwpXJjP5fNhQvn8NFH72Hz5k0AgA4dOuHtt99B\n69ZtSrwf+9kxHNLPZjOEjPS84Ey3CtGC5XTp70LLVuttnB+9tEStuyUMzR6eQKHl/Nvh7gGzR35o\neuSFqDtErUeRUIVWW+aA5OdZwvClcmM/l8+xY0fw/vvvYteuXwAAvXv3xVtvTUPjxk1stmc/O0ap\n+jk3F0JaGhQZxQVlqnVopqVBUWhZyMwo8yZZUauF2csbopcXRC9vmL3zl32k68JBefeIs/Cy1h1Q\nyH8QCz/PEoYvlRv7+f7s2/cnZs+egX/+2QdBEPDUU0/jjTemICws3Kod+9nOjEYIKSlQJN+BIvkO\nhDvStWeuDvrrN4sP0Yx0aeegMhAFAaKnF0Rvbyk4ve5a9vKG6C2FqNnLO29d/rJ0O9TqCuoIefDz\nLGH4Urmxn++fKIrYuXMH3nvvXZw8eRxqtRpRUcPx6qtvIDAwEAD7uURmM4TUFCiSky0hWjhQFcl3\nICTfgeJO3nVyMhRpqaV+eFGthujtDbMlQPOCMi9ArZel8LSMUr29IXp4VorRZmXCz7OE4Uvlxn62\nH7PZjM2bv8OHH87G5cuXoNVqER09Di+//Arq1w+pHv0sitIo0yo4k6G4Yx2iluWUZAgpKaXaYUhU\nqWCu4QfRz0+6riFdm/1qWJa9wmsjRVRLI868wIWbW5Xe6UcO/N6QMHyp3NjP9mcwGPDVV2sxd+6H\nSEpKhI+PDyZNmoQePfqhQYMIucsrPVGU9qS1MRq1CtSUZGn9HWm5NDsTiQoFxBo1pPDMD1K/QoFa\no4YUsr41LIErenrdM0T5eXYM9rOE4Uvlxn6uOHq9HitXLsfChfOQmiptJm3YsBH69n0C/foNQNOm\nzcp0rLDdmM0Qbt+GMukGFEmJUCQlQZFYaDnphnR96yaE3NzSPaSvr1WIWo9M85Z9a0D0ywtTb58K\n2ZTLz7NjsJ8lDF8qN/ZzxcvISMcff/yKr77agLi4XcjOO91deHhd9Os3AP36PYFWrVrffxAbjVDc\nvpUXoolQJCbeFah5t926WeLZhUQXF5gDg2AOCIDZz79oiNbwsx61+vjIcoIDW/h5dgz2s4ThS+XG\nfnaM/H7OzMzErl0/IzZ2C375ZQf0eh0AIDi4Nvr27Y9+/Qbg4YcfgVKpLLizwQDFrZt5o9OkvBC9\nAcXNJOuQvX2rxN9ORY0G5ppBMAcGwhxUC6bAQClk8y9BtWAODIToW8NpfyPl59kx2M8Shi+VG/vZ\nMWz1c1ZqKvbH/oCD27bg8l9/wFuvRy0AdTUaNK/hh1C1Gp46HRR3bpd4XKmo1cJcMxCmoFp5IRpk\nFbJSuAZKm3qdNFRLi59nx2A/S0oK38qxLYiousnKgjIhHsqEK1DExwOpt+B58Yr1ZuDkZIQCePLu\n+2ZnA9evIQPARYUCBv8AuNVvAP9mzSEE15HC1TJaDZIOhanioUrkbBi+RBUhOxvKawlQXLmSF7Lx\nUMRflpbj46G4dbPIXfInJDR7ecMcGAhj0+Yw1wwsGK3mBaohoCb+jr+CH3b9jG3btuLGjevArZvw\nOHYUPXr0RL/QAXisVWu4u7s79jUTUalxszOViP1cjNxcKK4m5IXpFSjyrqWQvQJlUqLNu4lqNcy1\n68AUGg5TaCjMIaEwhYbBq2kE7rh6wRwYJJ1Lt5TMZjMOHtyP2NgtiI3dgvj4ywAANzc3dO3aHX37\n9kevXn3g5eVtj1ft9Ph5dgz2s6RCfvOdM2cODhw4AKPRiNGjR6Nnz57FtmX4Oq9q288GAxTXrlqP\nWuPzlhPiobhx3ebvrKJSCXPtEJhCpVA1h4TCFBIKU2g4zKGhUrgW3lkqjz36WRRFHD9+DNu2/YDY\n2C04e/YMAECtVqNTpy7o128AevfuCz8/v/t6HmdWbT/PDsZ+ltg9fPft24eVK1dixYoVSElJwaBB\ngxAXF1dse4av86qy/Ww0QnHjuvWoNX85IR6K69ds7hksKhTSyDUktFCwhsEcGibdViu4XIfVVEQ/\nnz17BrGxUhAfP34UAKBUKtG+fQf07fsEHn+8H4KCatn1OSu7Kvt5rmTYzxK7h6/JZEJOTg60Wi1M\nJhPat2+PP//80/rwh0IYvs7LafvZZIIi8YYUpFcuW0asls3E167aPJZVFASYawVbNgebQkKlYM1f\nDq5dISfBr+h+vnz5ErZt24rY2B9w4MC/AABBEPDQQw+jX78B6Nu3P0JDwyrs+SsLp/08Oxn2s6RC\nDzXasGED9u/fj48//rjYNhXxJrRp4wmzjZHJuHG5GDnSkLeswd9/F/0PQevWJixfLp3IYO1aNebP\nd7H5HH/9pYOLC3DunAJDh7rZbDNvXjY6d5a+xHv10uL27aJ7lQ4ZYsB//yudCeidd1wRG1t0ZBQa\nasb330uzqWzfrsLUqa42n2/rVj2Cg0WkpgLdutneoWbKlBwMHiydwu/ZZ91w+nTRMwV17WrE3Lk5\nAICFC12wenXRQNFqRZw+rcStWxnYv1+B0aNt98GqVVlo2VJ6L9q2dYetswdGR+di9GjpfZk0yRV7\n9xbtg+bNTVi9Wnpfvv5ahY8/tt0He/bo4OEBXL4sYPBADWA0QDAYAEPetdGIJRiLx02xAIAO2Iur\nqFPwAAoloFLh/8L3YUbfP2EOCcP033rhu31hgEpptWdwrVpmxMZK78uuXUq88YYGtnz3nR7h4SIy\nM4HOnW2/L6+/noOhQ6XOGT5cg2PHCj6bCoUCZrMZHTsaMX++9L7ExKixfHnRz6ZKBfz9t3T875Ej\nCowYYft9iYnJwkMPSe9Lx45a6PXS6zIaTcjK0iMrS4/c3PkQxTkAAD+/b2A0doebmxvUhf6D0aiR\nGV9+mZX3OlV4/33b78uuXTr4+ADXrwvo39/279azZ+egTx+pDwYNckN8fNHPZr9+RsycKfXBRx+5\nYOPGop9Nf38RO3boAQB79igxebLt9+Xrr7PQoIEZublAu3buln4ubNKkXERFSZ/N6GgNDhwo+p3R\ntq0JS5ZIn82VK9VYssT2d8aBA9L7cvKkAlFRtt+XRYuy0a6d9J3RtasW6elFvzOefdaAyZOl74wp\nU1yxY0fRfy/16pmxcaP0vmzdqsKMGbbfl+3b9ahZU8TNmwL69LH9vsyYkYP+/aX3ZcgQN1y4UPR9\n6dXLiPffl96XefNc8OWXRd8XLy8Ru3frERDgiS1b9Bg/3vb7snZtFpo0kd6H1q1t/3uR87vcXirs\nUKOdO3fi22+/xapVq0ps5+urhUple1R8PxQ2Tj/n6alBQID0hms0ts9Q5+qqQECAOq998WexCwjw\nhIsLcOdO8W18fLQICJCWVSrb7dzdXREQIP3D0Gptt1GrFZY3ytu7+Ofz8/NAQEDxzwUAXl5ulppc\nXGy3c3NzQUCA9EH18LDdJn9DRkCAJ3x9i38+X193y/MplYCt8zh4eNzP+yICJhOQKwVswKyp8Dh/\nGBnH9VCkfFv0gRQKCA3qA62GAuHhwDf1gMy8sywpVZZwVT05CO4fDJJqugkoDhV9qLK+L25uxbfx\n9Cx4X1xdi7ZTKBTQaEr3vuTXVJb3Jb+di4sCLi7e8Pb2xvPPT0OdOvWwadMm/PxzMkQxFWlpqVCr\n1dBqtXB3d4eLi9ryfF5exT+fv7/0OcnJKb6Nt3dBH6jVtttptQV94F7M9LQqVUEf+JRwJsoaNaQ+\nyM0taHP390bh7wxb7wsAaDSl/86Qntd+3xnFfaZcXBSlfF+kz6bZbL/vjNK9L9p7vi9ASf9e5Psu\nd4Ryj3z37t2LBQsW4LPPPoOPj0+JbbnZ2Xk5tJ+NRigvX4Ly7Bkoz52BKu9aee4cFLpMq6aiQgFz\naBiMEQ1himgkXTeIgKlBBEQn3LO3MnyeU1NTsGPHdmzbtgW7d+9CTo40yqlb9wHLaS4jIx+U53zT\ndlIZ+rk6qMh+FkXAaATyNnbBYBBgMEj/wTIagdxcoci6/PWF/zYYhEL3kf7TMXSoAV5e9qvV7pud\nMzIyMGzYMKxevbpUe04yfJ1XhfRzVhaU589Bde4MlGfPQHXurBSyFy8UOVG/6OICU70GBeEa0RDG\nBg1hqldf+u9wFVHZPs+ZmRnYuVM6zeXOnT9bTnNZu3Ydy2ku27RpW+x+HpVVZevnyshkks7jkpsr\nBVlODpCTI11b3ybdLt0GZGdLyzk5AtRqV6Sm5tgIwMLhd3cAFl1XEKrW7SrKypVZls3v9mD38N2w\nYQMWLlyIunXrWm776KOPEBwcbLM9w9d53U8/C2mpBeFaaDSriL9S5DAds4cnTBERMDXIC9eIhtLf\noeGV5qT8Fakyf56zsrIQF/crYmN/wI4d25GengYA8PX1RYsWkWjZshVatoxEixaRCA0Nq9Qj48rc\nz2YzkJUlhVh2trSclSXkBaFgFXZSEBaEoK1l27cVF6QFj2syyfP+qdUi1GppBKpSiXnXgIuLdHvB\npaBd/nL+xcVFzLtP0XVqtfVjFrST1nl5AW3amGwdCVhuPLczlds9+1kUobiZJIXr2TPSaDY/bG2c\naMLsH5A3ipXC1digIUwNG8EcVKtanwLRWT7Pubm5+P33PYiN3YK9e/fgypXLVut9fHzQvHkkWraM\ntARyeHjdShPIZelnUcwf0UmBqNcXBGN2tmAVjkWvC9rq9db3KQhX679zciq+jwRBhKurFDouLiI0\nGmnZ1VXMuw3QaETL+vy2rq62lotfX7OmFnq9/q4ALByu1gGoUlXNf/4MXyo3Sz+bzVDEX8nbVHy2\n0G+yZ6FISy1yP1NIKEwNCsLVmBe2om8NGV5F5eesn+fU1BQcPXoER44cxtGj0uXSpYtWbby9fdCi\nRUs0b97SEsrh4Q/Y3GGyLERRGqnpdAJ0uruvbd8mii5ISTFYQq9wgNoKUbPZ/ong5iaFnkYjXbu5\niXBzk/7Ovy58e36wFQ486/C71/qCx3BUyDnr59neGL5UJkJKMlRHDkN15BA8Lp6F4dgJqM6fhZA3\nz2w+UaWCqe4DeTs8ReSNZhvCWK+BtDsklVpV+jynpaXi2LGjeYF8CEePHsGFCxcAuAPwAOABrTYI\nDzzQEqGhTREcHIGaNetBqw1EVpaiSGjq9cWHqz02kapURcPPVgjmL+cHp5tb6f8uuL1qjvDuVpU+\nz/eDsxpRsYSMdKiOHoHq8CGojhyE+tBBKO/alKjSamGMaGS9w1NEQ5jqPlAhJ5ygysFgANLTBaSn\nAxkZAtLSBMvf6ekCMjKKjjCloNRCpwuGTtfHchtgnTh6PXD8uHQpLa1WhLu7CHd3oEYNs2XZ+rrk\n22rXdkdWViY0GunxNJpqsUsBVUL82FUnej1Ux49BfeQgVIcOQnXkEJTnz1nt/GSuUQO5XbvB0OpB\nGFs+CO9Oj+C2WwkHk1KlZDYDmZn54WkdmmlpUnCmp8OynB+sGRkFt+WflKOslEoRHh5S2Pn6iqhT\n5+5QlJbV6hxkZNzAnTuXkZh4Htevn8aNG+cgiukAMgFkws1NRNOmddGqVSO0bNkSLVu2Qv36Dcq9\nl3VAAHDrVoVv7CO6J4ZvVZWbC9XJ49KI9vBBqA8fgvLMKatTKpo9vWB4tCOMkQ/CENkKxsgHYQ4J\ntd4uFuAJcPORQ4mi9Ptj4dC0DkncFZgC0tJQaFkKUVEsW3hKe3yK8PQEgoLM8PIS8y4otFxwm6en\nCA8PEVqtdai6uJRl02qtvEs7AIBOp8Px48dw9Oghy+/IBw/GYf/+Xy330Gq1aNq0uWWHrpYtW6FB\ngwioOIQlJ8LffKsCoxHKM6ehPnIob0R7EKqTJ6yOmRXd3GBs3jJvRCsFremBevcc0bKf7092NpCc\nLFhd7tyRrlNSCv7OzFQhOdlsGZ0aDGULTkGQQlMKTxHe3sWHZuG/vb0L7uPmVjl/j9Tr9Thx4hiO\nHj2MI0eky9mzp2Eq9B9JNzc3NGnSLG+HrlZo0SISDRs2KhLI/Dw7BvtZwh2uqhKzGcoL56E6fNAy\nolUdPwohK8vSRHRxgbFps7wRrRS2poiGlWa2HWeVk1NykOYvF/67tJtu3dwALy9zMSPNoiHq7Y1C\nISsWeyrKqiorKwsnTx63jI6PHDmMM2dOwVjoxOIajQZNmzbL28taCuSOHR9Gamp2CY9M9sDvDQnD\n11mJIhRXLhca0R6C6shhKDILXreoVMLUqIlls7ExshWMjZtK2/7soKr2c04OLCPPwkFaeDR6d5Dq\ndKULUq1WRI0a0sXXV4SfX/F/+/lJt4WEVM1+dqTs7GycOnXCKpBPnz4Jg8FgaaNUKlG7dh3UqRNi\nuYSEhFqua9euA1dX2xMUUOlV1e+NsuLezs5AFKG4cd0SsurD0rUiJaWgiSDA1CACuS1bFWw+btZC\nGjZVczodkJQkIClJgdu3bQdp4dFpZmbpglSjkQLygQfMRYLz7kt+kPLtkIdGo0GrVq3RqlVry205\nOTk4ffqkZXP1hQtncOnSZfz11x8obtwRGBiUF8YhCAkJsyzXqSOFtIeHh6NeElVhHPnKRLh1C+rD\nB6x2iFLcumnVxhReN29E21oa0bZoCdGj+P9JVQS5+zkzsyBUExMFJCUJSExU5N0mWNZlZNw7TF1d\nC8KzpCAt3EZrewY2u5O7n6uL/H7Ozc3FtWtXcfVqAq5eTUBCQrxlOT4+HtevX7XahF2Yr6/vXaEs\nBbMU1qHw8fGtNGf0kgs/zxKOfOWWlQX1/n+gOrgf6vxDfK5dtWpiql0HOY/3LxjRtoyssmeDEkUp\nVPNDND9Uk5IKQjV/3b029fr7mxESYkZgoIigIBGBgWYEBBQfpNX8O5HyuLi4oG7dB1C37gM215tM\nJiQlJSIhIQFXr8YjISHesnz1agLOnTuDo0cP27yvu7tHoVCWRs/5f4eEhCIgoOZ9n92LnB/DtyIY\njVAdPgiXvXug3rsH6n//hpA3PRsgnd84p0cvy2+0hpYPQqxZU8aC7UMUgfR0WI1MExMVuHlTsBq1\n3rxZ8o5IgiCFZt26+aEqXdesWRCwQUEiAgJEe/20TWRFqVQiOLg2goNro23bR4qsF0URd+7cQULC\nlbyRc0Ewx8dL16dPn7L52K6urggOrm0VyvnBHBISilq1gnnYVDXAd9geRBHKUyfhsjdOCts//7Da\nKcrYtDlyO3aG4eFHYGz1IMzBtZ1qCCaKQGoqrDb9Wo9SC/7Ozi7+dSkUIvz9RdSrZ7aEaGCgaBWw\ngYFSqPLEWVSZCYIAf39/+Pv7W/3GXFh6elpeKCcgIeGKZVkaSSfgt99227yfUqlErVrBVjuF1ahR\nA76+NQpd+6FGjRrw8vLmKNpJ8TffclJcvpQ3so2Dy++/QXH7tmWdse4DMHTsgtxOnWFo3xGiv79s\ndZaG0Qhcvy4gPl6BhAQBV64oLMtJSSrcuCGWOOOKQiGNSvM3/dasab0ZWLqWgpf/obdN7s9zdVGZ\n+lmv1+Patat3/d58xbKcmHgDZrO5xMdQKBTw9fWFr68Uyn5+fpbl/KDOX65RI3+dL1wqeJNRZepn\nOfE3XzsQkpLg8ru0Gdnl99+gjL9iWWcKDEL2U08jt1MXGDp0grlOiIyVFmU2SzstXbkiBWp+sMbH\nSyF77Zpg8wT1CoWIWolNSOEAAAkTSURBVLWAJk3Md41SrUet/v6iXefAJKoOtFotGjSIQIMGETbX\nGwwGXL9+DdevX0NycjJSUpILXd+x+jslJRmXLl20OvFISTw8PG2MpouGduEwd3d3r/Y7ktkTw7cY\nQloq1H/+IY1s9+6B6sxpyzqztw9yHu8vbUru1AWm+g1k3YwsisDt24JVoMbHFyxfvSogN9d2fUFB\nZjz4oBmhoWaEhZkREiIiNFT6OzhYRHCwJ27d0jv4FRGRWq1GWFg4wsLCS9XebDYjPT3NKpALL9+5\nU/T2s2dPI6vQCXpK4uLiYmMUbTu869ULQW6uAHd3d2i17uU+F3dVxvDNl5UF9T/7LJuSVUcOQ8jb\n5CO6uSG3y2PI7dgFhk6dpWNrHfxhSk0FEhIUVqPXwiPY4nZg8vc3o2lTsyVQ88M1LMyM2rWlWV2I\nyPkpFAr4+PjCx8cXQL1S30+v1xcJ6sIj7Ltvv379Ok6dOlmm2jQaDdzd3eHu7gGtVpsXyh5511q4\nuxddzg/u/PvZauvMv3dX3/A1GqE6dMB6j+S8cyGLKhWMDz1sGdkaHnxImp26AmVmSuEaHy9YQjZ/\nOT5egfR02+Hq5SWdACI/WMPCCpZDQszg+QCIqCRarRZarRa1a9cp9X2MRiNSUlKKDe2srAwkJ6dC\np9PlXTKh1+uh0+mQlJQInU6H3ELnnr+/2u8O6sIhbyvIrf8TkL/s6+sLb2+f+66ptKpP+JrN1nsk\n//WnZY9kURBgbNYChg6dYOjUGblt28PeqZWTgyKbhfODNT5ewJ07tv8Hp9VKI9VHHpHCVBrBFmwa\n9va2a5lERPekUqkQEBCAgIAAm+tLs8OVwWCAXq+zBHTBcmbe33rLsvV66zDPb5OamorMzIxS/+59\nN4VCga+++haPPda9XPcvq6obvqJYsEfy73uK7pFcrz5yBg+R9kh+tCPEGn52eVq9Hjh/XoEzZxQ4\nezb/WonLlwWYzUVHry4uIkJCRDRvbiwSrKGh0vGu3MeBiKoatVoNb28fu442RVFEbm6uzXC2DvOi\n6wEgIqKh3Wq5lyoVvoqkRGlUm79HckK8ZZ2pVjCyhzyD3A6dYOjYGeYybGKxJTMTOHs2P2CVlqBN\nSBCKzKNao4YZbdqYUK+eFKjSCFbaRFyzplitZqMhIqoogiDA1dUVrq6uqGGnAVVFcerwFdJSof7j\nd2lT8u+/We+R7OuLnH4DpLDt1AWmevXLtUdyaipw5owS584VjGbPnlXg2rWiiVmzphkdOpgQEWFG\nRIQZDRtK1/7+FX4oNRERORHnC19RhNvC+cCOWPgdOFCwR7JWi9zHukt7JHfsJO2RXIYh5e3bQqHN\nxAWbjG/eLPoYwcFmdOlitISrdDHB19dur5KIiKow5wtfnQ7un3wIGI0wPPwIDB07S5cHH7rnHLai\nCNy8Kdz1e6x0sbXDU2ioGd27G/NGsdKItkEDM7y8KurFERFRdeB84evhgTv7j8M/LBBpetunXhNF\n6XSJ1qNY6XfZtDTrTc+CICI8XESbNgarzcX165vh7u6IF0RERNWN84UvADEgAHB3hzkzAwkJgtVe\nxfnLd09Fp1RKx8N26GC22lxcr56Zk58TEZFDOV34iiIwfbor/v0XOHXKA1lZ1iGrVouoX99cZKen\nBx4wc/o5IiKqFJwufPV6YMMGNbKzYQnZ/IBt2NCE8HDOnENERJWb08WUuztw4kQmAgM9kZzME/4T\nEZHzccrTO6jVDp/XgIiIyG6cMnyJiIicGcOXiIjIwRi+REREDsbwJSIicjCGLxERkYMxfImIiByM\n4UtERORgDF8iIiIHY/gSERE5GMOXiIjIwRi+REREDiaIoijKXQQREVF1wpEvERGRgzF8iYiIHIzh\nS0RE5GAMXyIiIgdj+BIRETkYw5eIiMjBnC5833//fTz99NMYOnQojh49Knc5VdqcOXPw9NNPY/Dg\nwfj555/lLqdKy87ORvfu3bFp0ya5S6mytmzZgieeeAJPPvkk4uLi5C6nStLpdBg/fjyioqIwdOhQ\n7N27V+6SKi2V3AWUxT///IMrV65gw4YNuHDhAqZMmYINGzbIXVaVtG/fPpw7dw4bNmxASkoKBg0a\nhJ49e8pdVpW1dOlSeHv/f3v398r6H8Bx/LkzubBxzDJaIblRSigXWHJBLlz7kRa3cqVc0FKUq7lS\nKAp/gLZwI0pZuZgr5UJRXGExy8evxgU6d6fOt9x8a3vbp9fjbrt61i5ee38+n7bfpjNsy7IslpaW\niEajpNNpFhYW6OjoMJ1lO5ubm1RXVzM+Ps7d3R3Dw8Ps7u6azvqRcmp84/E4nZ2dANTU1PD09MTr\n6ytut9twmf00NzdTX18PQFFREW9vb3x+fuJ0Og2X2c/l5SUXFxcagwyKx+O0tLTgdrtxu93Mzs6a\nTrIlj8fD+fk5AM/Pz3g8HsNFP1dOXXZOpVL/fJglJSXc398bLLIvp9NJQUEBAJFIhPb2dg1vhoTD\nYSYnJ01n2Nr19TXv7++MjIwwODhIPB43nWRLPT09JBIJurq6CAaDTExMmE76sXLq5Ptf+mXMzNvf\n3ycSibC+vm46xZa2trZoaGigoqLCdIrtPT4+sri4SCKRYGhoiIODAxwOh+ksW9ne3sbv97O2tsbZ\n2RmhUEjPMXwjp8bX5/ORSqX+vk4mk5SWlhossrfDw0OWl5dZXV2lsLDQdI4txWIxrq6uiMVi3N7e\nkp+fT3l5Oa2trabTbMXr9dLY2EheXh6VlZW4XC4eHh7wer2m02zl+PiYQCAAQG1tLclkUrervpFT\nl53b2trY29sD4PT0FJ/Pp/u9GfLy8sLc3BwrKysUFxebzrGt+fl5otEoGxsb9Pb2Mjo6quHNgEAg\nwNHREV9fX1iWRTqd1v3IDKiqquLk5ASAm5sbXC6XhvcbOXXybWpqoq6ujoGBARwOB9PT06aTbGtn\nZwfLshgbG/v7Xjgcxu/3G6wS+X/Kysro7u6mr68PgKmpKX79yqmzR07o7+8nFAoRDAb5+PhgZmbG\ndNKPpb8UFBERyTJ99RMREckyja+IiEiWaXxFRESyTOMrIiKSZRpfERGRLNP4ioiIZJnGV0REJMs0\nviIiIln2BzQKNGAGnBgwAAAAAElFTkSuQmCC\n", - "text/plain": [ - "\u003cmatplotlib.figure.Figure at 0x7f7a18df6b50\u003e" - ] - }, - "metadata": { - "tags": [] - }, - "output_type": "display_data" - } - ], "source": [ - "# Train our variables.\n", - "\n", - "# numpy is used for its asscalar() function.\n", - "import numpy as np\n", - "\n", - "num_training_steps = 10\n", - "\n", - "def train_model(inputs, labels, wb, optimizer, num_training_steps):\n", - " loss_at_step = []\n", - " w_at_step = []\n", - " b_at_step = []\n", - " for step_num in range(num_training_steps):\n", - " loss_at_step.append(run_step(inputs, labels))\n", - " w, b = wb.variables\n", - " w_at_step.append(np.asscalar(w.numpy()))\n", - " b_at_step.append(np.asscalar(b.numpy()))\n", - "\n", - " print(w_at_step)\n", - " t = range(0, num_training_steps)\n", - " plt.plot(t, loss_at_step, 'k',\n", - " t, w_at_step, 'r',\n", - " t, [true_w] * num_training_steps, 'r--',\n", - " t, b_at_step, 'b',\n", - " t, [true_b] * num_training_steps, 'b--')\n", - " plt.legend(['loss', 'w estimate', 'w true', 'b estimate', 'b true'])\n", - " plt.show()\n", + "## Next Steps\n", "\n", - "train_model(inputs, labels, wb, optimizer, num_training_steps)" + "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)." ] } ], @@ -572,7 +312,7 @@ "colab": { "collapsed_sections": [], "default_view": {}, - "name": "Eager Execution Tutorial: Working with Gradients", + "name": "Automatic Differentiation", "provenance": [], "version": "0.3.2", "views": {} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..84f1d031d40604ae029e8a8347474950ee01b38a --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_training_models.ipynb @@ -0,0 +1,485 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "k2o3TTG4TFpt" + }, + "source": [ + "# Training Models\n", + "\n", + "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n", + "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n", + "\n", + "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "3LXMVuV0VhDr" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "PJ64L90aVir3" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager # Shorthand for some symbols" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eMAWbDJFVmMk" + }, + "source": [ + "## Variables\n", + "\n", + "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "VkJwtLS_Jbn8" + }, + "outputs": [], + "source": [ + "# Using python state\n", + "x = tf.zeros([10, 10])\n", + "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n", + " # value of x\n", + "print(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wfneTXy7JcUz" + }, + "source": [ + "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n", + "\n", + "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "itxmrMil6DQi" + }, + "outputs": [], + "source": [ + "v = tfe.Variable(1.0)\n", + "assert v.numpy() == 1.0\n", + "\n", + "# Re-assign the value\n", + "v.assign(3.0)\n", + "assert v.numpy() == 3.0\n", + "\n", + "# Use `v` in a TensorFlow operation like tf.square() and reassign\n", + "v.assign(tf.square(v))\n", + "assert v.numpy() == 9.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-paSaeq1JzwC" + }, + "source": [ + "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n", + "\n", + "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "BMiFcDzE7Qu3" + }, + "source": [ + "## Example: Fitting a linear model\n", + "\n", + "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n", + "\n", + "1. Define the model.\n", + "2. Define a loss function.\n", + "3. Obtain training data.\n", + "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n", + "\n", + "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "gFzH64Jn9PIm" + }, + "source": [ + "### Define the model\n", + "\n", + "Let's define a simple class to encapsulate the variables and the computation." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "_WRu7Pze7wk8" + }, + "outputs": [], + "source": [ + "class Model(object):\n", + " def __init__(self):\n", + " # Initialize variable to (5.0, 0.0)\n", + " # In practice, these should be initialized to random values.\n", + " self.W = tfe.Variable(5.0)\n", + " self.b = tfe.Variable(0.0)\n", + " \n", + " def __call__(self, x):\n", + " return self.W * x + self.b\n", + " \n", + "model = Model()\n", + "\n", + "assert model(3.0).numpy() == 15.0" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "xa6j_yXa-j79" + }, + "source": [ + "### Define a loss function\n", + "\n", + "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "Y0ysUFGY924U" + }, + "outputs": [], + "source": [ + "def loss(predicted_y, desired_y):\n", + " return tf.reduce_mean(tf.square(predicted_y - desired_y))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "qutT_fkl_CBc" + }, + "source": [ + "### Obtain training data\n", + "\n", + "Let's synthesize the training data with some noise." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "gxPTb-kt_N5m" + }, + "outputs": [], + "source": [ + "TRUE_W = 3.0\n", + "TRUE_b = 2.0\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n", + "outputs = inputs * TRUE_W + TRUE_b + noise" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-50nq-wPBsAW" + }, + "source": [ + "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 293 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 1210, + "status": "ok", + "timestamp": 1527005898290, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "_eb83LtrB4nt", + "outputId": "3873f508-72fb-41e7-a7f5-3f513deefe38" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEDCAYAAAA2k7/eAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJztnXlgU1X2xz/pAhRautCWUsCwWVlcUHHGBUFQcSg7uM8P\nFLUICo4VpygObihI3UdmUHBB0IGZQbEgFNGCqKgMolV2pKylCy1pukDp+n5/3LxmaUsDTUjSns8/\nbZKXd09C+b7zvvfccw2apmkIgiAITR4/TwcgCIIgnB9E8AVBEJoJIviCIAjNBBF8QRCEZoIIviAI\nQjNBBF8QBKGZENDYE+Tk5JCUlER+fj7+/v7cdtttTJgwgcLCQhITEzl27BidOnXijTfeICQkxBUx\nC4IgCOeAobF1+Hl5eeTn59OrVy9OnjzJ2LFj+ec//8mnn35KWFgYCQkJLFy4kKKiIh5//HFXxS0I\ngiCcJY22dKKioujVqxcAbdq0oXv37uTm5pKWlsaYMWMAGDNmDF999VVjhxIEQRAagUs9/MzMTPbs\n2cNll13GiRMniIyMBNRFoaCgwJVDCYIgCGeJywT/5MmTPPLII8ycOZM2bdpgMBhcdWpBEATBBbhE\n8CsrK3nkkUcYNWoUN910EwDt2rUjPz8fUD5/REREg+eRtj6CIAjuo9FVOgAzZ86kR48e3HPPPTXP\nDR48mE8//ZRJkyaxcuVKbrzxxgbPYzAYyMsrdkVIbiUqKkTidCESp2vxhTh9IUbwrTidodGCv23b\nNlavXk1cXByjR4/GYDCQmJhIQkICjz76KJ988gmxsbG8+eabjR1KEARBaASNFvwrr7yS3bt31/na\n4sWLG3t6QRAEwUXISltBEIRmggi+IAhCM0EEXxAEoZkggi8IgtBMEMEXBEFoJojgC4IgNBNE8AVB\nEJoJIviCIAjNBBF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCaI4AuCIDQTRPAFQRCa\nCSL4giAIzQQRfEEQhLOk0GTi84R7+XbIDXyecA+FBSZPh+QULtnTVhAEoTnx7YzHuDflUwyAlv4z\nizEwfNFiT4fVIJLhC4IgnCWhhw9hsPxusDz2BVwi+DNnzuTaa69lxIgRNc/Nnz+fAQMGMGbMGMaM\nGcM333zjiqEEQRA8TqHRiGb5XQMKjV08GI3zuMTSGTt2LOPHjycpKcnu+YkTJzJx4kRXDCEIguA1\nXJ/8OosxEHr4EIXGLlyf/JqnQ3IKlwh+v379OHbsWK3nNU2r42hBEATfJjQ8wic8e0fc6uF//PHH\njBo1iqeeeori4mJ3DiUIgiA0gNsE/+677+arr74iJSWFyMhI5s6d666hBEEQXMLRjAwW9u3FWmN7\nFvbtxeGMDE+H5FLcVpYZERFR8/vtt9/O5MmTnXpfVFSIu0JyKRKna5E4XYsvxOmNMb53xQhmZh1T\n5Zalx5h3ww08cfSop8NyGS4TfEe/Pi8vj6ioKAC+/PJL4uLinDpPXp73Wz9RUSESpwuROF2LL8Tp\nTTEWmkx8O+MxQg8fIjory67cMtZk8po4z4SzF0+XCP706dPZsmULZrOZG264gWnTprFlyxZ2796N\nn58fHTt25Pnnn3fFUIIgCC7FdhHVXFSZpcHyM8vGqWgKuETwX3311VrPjRs3zhWnFgRBcCu2i6ju\nBp4JDKR7QACZ4RH839dfezAy1yMrbQVBaNbYLqK6AOgaP4L4w7lMSt+NsXt3T4bmcqSXjiAIzRpf\nXUR1LojgC4LQrPHVRVTnglg6giA0WXy1jbG7kAxfEIQmi6+2MXYXIviCIDQZbGvqC41Ggg9k+GQb\nY3chgi8IQpPBMaOfE9vRrq7eV9oYuwsRfEEQfB49s/dbn2qX0V8QEcHiq/7odAWOyWRmxoyNHD7c\nFqOxkPffHwX4uzv884YIviAIPk2hycS/B1/HJVnH2In9StnK7heelWc/Y8ZGUlLGAwbS0zWmTFnO\n/PnD3RK3JxDBFwTBJzmakUHquOFE5mTTpbqaAcD1wDygQ1AQ1UOGnnVN/eHDbcHmHuHgwWDXBu1h\nRPAFQfBJUscNt3a2BJYDdwG9gRNDhp5TNY7RWEh6uvUeoWvXEhdG7HlE8AVB8BkKTSY2PDqVgB+/\nI8ZstvPrg1HCvz22I3ec42rZ5OTBwFKLh1/EggUjqapyTezegAi+IAhejz4pq23aQBuzmWHAAuz9\n+t8CA8mPH8Edya8RGn5uXS7tu7w3vS1aRfAFQfBq9ElZR/vmbuBZwGgwkN0hlqEr19C5a7dGjdXU\nJ22ltYIgCF7NtzMe4xKL2IPVvrkAuAgwjBzDpPTdjRZ7aPqTtiL4giB4NaGHD1GC1WDR7ZvktqFs\nbH85r2eMJiHhUwoKzI0ey2gstBtJJm0FQRDciF5u2anARGZ4BC169eIBlI3TBsuk7MbNPJ60Sdkv\nuQa279CApSxaNOasx7NdbNWhw0mGDn2P7OxImbQVBEFwF/rEbNba1cysqKjZSPyF6mo+GzWW0MOH\nOGHsUjMp62i/HD7cttZK2eTkwYSHh51xXEffftSopaxffyMAERHes/euKxDBFwTBK9D74HwO9u0R\nCs3E11FT71gzbzQW1RJvZ7L+ui4cTRWXCP7MmTP5+uuvadeuHatXrwagsLCQxMREjh07RqdOnXjj\njTcICXFuZ3VBEJo+u7Zt48vRQ+ladpqDBgNRrVtjAIqxL7fMrKfE0rFmPjl5EHfcsY2zFe+6LhxN\nFZcI/tixYxk/fjxJSUk1zy1cuJBrrrmGhIQEFi5cyDvvvMPjjz/uiuEEQWgCfDkmntllp5XMahpP\nnzyJBsQDy4BC/MhoFcawxcvqfH94eFit7P1cxLuuC0dTxSWC369fP44dO2b3XFpaGh999BEAY8aM\nYfz48SL4giCwa9s20sbG0+10qZ110x14JSwMf8LZZO7HKt6G0+Hs/8dSFi3q69S5z0W867pwNMS5\nzBV4A27z8E0mE5GRkQBERUVRUFDgrqEEQfAQZyN8+qTs6VUruUjTOIS9dZMNxAwczN8Pjyc9fXTN\n+87GUz8X8T4XzmWuwBvwuknbqCjf8PklTtcicbqW8xXn1Kmf2wlfy5bL+fe/76p1nPnECVbc1J/e\nmZmUAEOBpajOltHAAYOBsBtvZMz7i1g3ZZ2dLRMXV4qfXxUPPZTKwYPBdO1azIIF8UREnJ+Muq7v\nMisrHNu5gqyscJ/423Cb4Ldr1478/HwiIyPJy8sjIsK53ha+UAIVFeUbpVoSp2uROGuzb18QtsK3\nb18QeXnFtTL/+PIVzMjMtGuN0BUYDsxqFcRfjuQCUFEFs2dfT1mZ1ZaZPXsQ99+/qubCsnWrRlnZ\nUubNG+R2W6W+7zI21oTt/UlsbIFH/zacvdi4TPA1+65DDB48mE8//ZRJkyaxcuVKbrzxRlcNJQiC\nl1DfJKlueRgwcUH6FPBbb+fXtwF+A7a0CuLmVevszlmXLVNX6aQnbRVfneh1ieBPnz6dLVu2YDab\nueGGG5g2bRqTJk3iL3/5C5988gmxsbG8+eabrhhKEAQvoi7hKzSZCN34D17hUfIoYS4VLKu29+t3\nderEHWnfOd3Vsq4Liyfr58/XXIGrcYngv/rqq3U+v3jxYlecXhAEL8VW+ApNJj57YAKl337NDSgp\n7mD5GY+yccotO1FNfn8RuXknSUhY6ZQlU9eFJSlpQ7Opn3cVXjdpKwiCb/LtjMeI/fZr7sKayb9k\n+RkG3AkstuxEFRYRwr33fe60JaNfWPS5gTvu2Far742v2CqeRARfEASnqasMs9BUwOJxCVyanU4+\nUIgSeAOqffFLQPuwMAwDB3N98muYTGamTv2c9evB1pLJyPBvMOOvr++NyWQmKcn36uLPNyL4giA4\nja3g/pqeT+DqP3BJ9UHmY83ql6E2J9GAn4Hc9pezLCqRblRzHX4251iGrbNvMh1mx44nOVPGX59v\n76t18ecbEXxBEJxGF1wDJ5jM5fyjOrNWs7NC4G3gd/zZenUS3/74ol0LY6toK2c/KKiCIUPgwIE4\nsrLOPAnrOHl7/PguDhzowaZNucDnqE488U26AVpjEMEXBKEW9a2g7djhGMb0eG5mHa3R6mx2to6r\nWcUPwCrC9uRbXjEDqaxfD+HhO4GBNe8oKytl69Z8evUKtjuTPglr36++nPbtnyY39yrgJFlZUxg7\ndgFm85PY3mMYjZVn/BzNFRF8QRBqoSySEcA60tPD2bp1CY/cW0nv1GfpDeQBuWDX7Kwc2A2s4l+W\nV04C+ZbfU4E7KS01UFqqERT0DKWlLYGZVFcbyMrSqK5+gVGjate2O9o1YWGvACNrYi0o6ITtPUZY\n2GmSk2+u873N3eoRwReEZsDZZrrKElkH3Ik/WxmSNYaCOdX0B0qAB4BVwNNAJ9QFIB9Y1vZeKNoO\n/Aj8iWuuWU6LFktZvx5KS62iXFraDyXS1ucKC411irGjbw/tsL0TCA8/Smmp9fHAgQE1n6059bp3\nBhF8QWgG1JXpnqk1gdFYyK/pp4nnEv7ATiKBUGCA5edyIALoDOwliNfIJCzsM7744g/MmfOz5Zyr\nSU4eTnh4GAkJn5KSYmv8nLT8tBXuzDpjd/Ttr7mmmhYtrHcCM2eOYs6cule9Nqde984ggi8ITRDH\njP7AgTY01JrgxImX+emnYsrKuhKifcxf2EAw0BdqGp6lAnehWiMUA7tpxRtsB8Ixm1vx7LPf0aJF\na8s41nYrtgunjh/fRVbWFEs8y/DzKyYm5gQrV1ptGltqL7q6pdbdyaJFRiff27xr9UXwBaEJ4ijm\nsbFzcJwQdbQ7Nm8uAm0ioxjFzeykEHgCa06+HNCnVf8H/MKlrGUssMvyTDybNy+kqOhBHD1z2xW5\nBQVXMmvWOvbtC8JorCQ5Of6M9lJj2hj4agsEdyGCLwhNEEcxj4jowlVXnbk1gb92kkR6MM/yzCrs\nnfM2wK/Atxh4mXuB14A1qJ6X6hynToXSkGceHh7Gv/99l090Hm1qiOALQhNEedcFqInXlvz++06O\nHGmFn18nOnSoAuztjuh2+7k07Q36Y5XrEuzLLb8DXmYDMAhQ1TLV1aUUFS1BOfoltG5toqiocZ65\nlFK6DxF8QfBRMjIOM27cKgoKOhEefpSVK0fRtavyspOTB7N16wKyslR9elnZGMrKlgHxpKauZfPm\nLwgOziW8bTsuP/4SxvTddMZe5IcCs4COwC78mc8e4D+Wo4rp1CmW7t0rSUmZgC7w/fq9xZ49cy0x\nZTJzZm1fXm+toCyd2oIupZTuQwRfEFzI+cxOx41bVSPopaUaI0c+w9VX9yArK5zYWBMREUa7lasQ\nhFoDO4PiIhN9i/7ENVk/0Qb4G/AKcDvKq28DbAPKgHXczKqaupyLgRGAxv79s3jvvTuxnRQtLw+0\ni2nOnKW1JlQbEnQppXQfIviC4EKczU6dvTDUd5zJZCYnR8O2nUBeXsuasdUuTHOxN2X2AH3wYz/j\nuYgYNHoDpZYjwlBVOCGWM5qBv/Moyqu3LacEMHD6tCrBtP18Q4akoZorpALBbNqUQ0GB2e6z2Qt6\nIZs25TJkSFrN55NSSvchgi8ILsTZ7NTZC4PjcWVl7wGQlpZLdfVMbNsJaFqE3dgFBbG0ajWL06fb\no0Q4GH/SmcooWgNXo8wZ/Qy3AWuBTGB/y1DSus7B//dDVFUtAyqBY8Bky/mV+Dt+PiXWa8HSJNls\nHk5S0lK71saHDlUCHwPDgLWYzY+Tnm79HqSU0n2I4AuCC3E2O63vwmCb0cfE5PH99/arUX/80Q+z\neSLUallWTnCwieJi69iqdcFsYmJe4FTO10xhA3HAPuBFrEK/BJiLMmwOGAL5vPsLxPVpz6fJg3n0\n0dWkpgLkoMR+Hcrw2QU8SEzMJ3YtjWfOvJJNm/6H2Xzmjpb6pC+0q3WslFK6DxF8QXAhzmanHTpk\nk57+L5SBUkSHDrZ7waoeNtAeVd9uFfGiohzL744ty/IpKTlOy5azKC/vhqYFoaZdDbSqzOMeNjCX\nusstw1E9cF7yv4viqo9hv4Hd+zVSU5+ksjIEg8FAy5a5lJXNQ9PiwLLQKiZmPgZDJCkp92N7pzJw\noL/dqlr9oud4kevS5UKMxsI6jxXcgwi+ILgQ57PTQLDbG+o9TCazpc1viuX1AcD1wDwgBmhBdXWU\n5fjrUHl5NKqTzd1o2mbKyu5CtTK7k0BWkMjtdM2HFtRfbvk9MI+HoeoKm6MKKS9vA1wClHD69KXA\nBJt3LeP06WNkZ3fA8U7l3/++krouenXd/Yh9c35xu+APHjyY4OBg/Pz8CAgIYMWKFe4eUhC8nuzs\nSGyFMjs7khkzNmI2P255vgBVUdMH8ENJdjzqYvAekAHMwX4dbIjl8XX48xBTeJvLLM9uwb7c8kng\nAiCdIBaRALyB/YYka1G1O/r5P8T+viAEaFeniNd30bMV97i4UmbPHiT2zXnG7YJvMBhYunQpoaGh\n7h5KEHwG+4VRbTh+fCdVVRdhFdV1wAzL4+G0aJFEefkBVG/KAqAH9gIcDBThzxb+j6uJRTnt+j1E\nf+ApoCfKfd9LIPN4CUjEKub6VuOlqEla2/PnYX9fUMw111Rb2hA7l6HbintUVIistPUAbhd8TdOo\nrq529zCC4HHOpgbfcWFUVtYITKZZwDisjQysgtuyZSTl5UlYBVffHlx/vJcW/EoiH2EE2qLuC/Qz\nhANGlNjPYxgwGnXXsAw4hf1W48ss77I9fy7qziIAP78sbrklnDfeGC4Zuo9xXjL8+++/H4PBwB13\n3MHtt9/u7iEFwSMkJq4hNbUt4E96egDl5Z/z4Yf/V+ex4eFhREf3tlsYdfp0F5RfHw38jvLvwwGN\n4uJg7DPuzsALQCcCWMOdfEJHqJmYreuSsA94jf0og8d2/mA2yh5S7REgwTKOyvZbtjxAWdlUoAug\nMWKErHz1Vdwu+MuXLycqKgqTycTEiRPp1q0b/fr1q/f4qKgQd4fkEiRO1+LNcZ44Yeahh1I5eDCY\nrl2LWbAgnoiIsFrHfPVVNqA6RYLGjz++WvO5Tpww88ADKWzapAF5DBgQRocO2PnfaguRGTaP56E8\n/BIgG3v5Pgp0pxWf8hc+IRi4FPtLQk9Urm4GdhDCAn4BuqPyfNsjL0c1QHseqEB1vDcAd9Kp0zx+\n/fVxpkxJZd++X8jP38vhw0amTl1d5/dwNnjzv7ktvhKnM7hd8KOiogCIiIjg5ptvZvv27WcUfF/w\n9XzFf5Q4XUNCwqqaUsmtW4P57rt/sHHjBMLDw2r62eTktKO6ugu2QlpcHMK+fUctG4Cssus5k5Ky\nBPgNeBWIRAl4H+yFuDd6GwNYgLVBcQkBHGIqM7kQ5ei3pnb1zS6gCEjmJ9Qq226Wcxc5HKkvv7rC\n8rt1nLCwzlRV+TN//nASElaSnj6DzEx9Edi5Z/re/m+u40txOoNbBb+0tJTq6mratGnDqVOn+O67\n75g6dao7hxSEs8IZ3z0jwx94B1XXspOsrF4MGrSEjRsn2PWzUatHrUJaWdmKQYOWEh3d27K61FbM\nNVT3Guvyp8DAX6ioGIO9ZBuAdJSFcyd+7Gc0/ehOUU0bYw1l7tyLtQ/O98AxgviI7aisvrvlqF7A\nXtQU7oVAS9RkrS781cDdNWfu3n1pzfcgPW58H7cKfn5+PlOnTsVgMFBVVcWIESPo37+/O4cUhLPC\nmRYHJtNhVCHjcvQtQbKyNJKSllJQYFuHPgyVscehWo/dR1bWf8nK8kdNehage/JwANs+OFBJUFB7\n2rV7kZycSNS0613AZtQdwAECuZ1EVnARqqmZ7eWjI+oeIBz4BQMvk4Ta+1XP6kNRtf0VwHPAEWAp\nEAXMx8+vkN69+9K5cxHwHtnZkbJdYBPErYLfuXNnUlJS3DmEIDQKZ7LWdu3iLJOr9hOnq1ZVomm7\nsWb1eulxgeXnRqADavJ1OJCEqoTRd4PNRU3QLgBKKSp63tJL/hmU4P8XmI4BE4MZSD920hvV0cYP\ne1PGhDJqnuIOVOb+PKp/zjKgHFWRU2zzGb5HZfnqDDExc9mwwdrKWL/zueOObTV3PrJIyveRlbZC\ns8aZrLVbt5Ns365qz21lVrUviEOJahCqQUEgyjL5K9a+MwuAKajFSvYNz2AkyqefZxnNgLJdDgNt\n8eN1pjKDi6gkFHUPEYqa2l2GtbNlJvAmM1C1OXrzhDCUPbPaMsYCwsJ2YzYPx/Hi1a5dnN1nru/O\nR6pzfBsRfKFZo2etGRmtMZn2kZFhJCHhUzsv33qMH/v3P83p00bUQqRhwHpU24NiVLZ+P8qqWYeq\naTegxHYJyj5xXK2q/x6MkvA2qJLMSIJYwiNs4VpqbyLeFdiBWob1Bd1YxR2oOwioPX2rHgcGmtiy\nZQJJSUvZtCnHIvzqmG7dTtl9L+LXN01E8IVmjb5w6J57PmbHji5kZYWwY0cOP/zwHuXlFwD5XHNN\nMG+8MYJZs75jx47nsbY+eB3ohxL7oSj/Xq+jz0ZZKmGWn0eAVjiuVlVoKKPmYcBAC/L4Cw/QAlUh\nb9s8Qd9E/DCQg4G5bEU1MzuFaocQgrJwnkDdfRxH1c8vY8CAkJrPW1BgJimpfntG/PqmiQi+0OQ4\n212nTCYzX32VhbWG/l8cP/4Mutilpi6jRYuNZGWFY93cIwNV6a7L8XxUhYttHf0ylKWi96XRPfVi\nlOtegrJgwlGZfSEB/I9HeYCXUFOqtvcDbVDSvhmYxzxURq8Bn6CqbabYjP0CMBbdVoqN3cE//zm+\n5jM3tEJW/PqmiQi+0OQ42z1RZ8zYSEVFP6zyGoKj9ZKSchz4BiW5T6KyedvVqi+gJktt31cIvI+q\njLH11FehBD8Y/QLhzxbGE04HrF1yjlG7q2UJLfgHPwArLec5iTJ42tuN3bZtB6677hNLtY2Z5OTx\ntS56Z7owSsuEpokIvtDkcPSfMzL87TbpePLJK5k792ebTUayUTX2+i5MjguTTKiKmj4o774QVSrp\n2Opgj8P7TqKmWG1ragpQ9fVRwDEMHORmRtGXHXRB1ebonW3uRuX/eqOFDYRyqtfrxBauo23bzhQV\n7aBduzgOHy6nqGgnaq5AjT1oUIsGBVs2C29+iOALTQ6r/1wIrGXv3kPs2KGqY9LTNdasmUVl5eya\n15V4ZwMXofZvLUJZNsrDV4+fw75LDdiLeybKdHkS1aYsFHjA8vM9VK+aXqhFVOpcLXmTKXSnDfZe\n/RKsPStPAj8Bb/MhsbGZpG+6tdbn7dv3LYqKpqAvuwoK+onk5IRaxzkiE7PNDz9PByAIrsRkMlNe\nXkFY2AcEBLwCDKWiwr7LTGVlZ8vjVNRkaxFqknMsSozboiY6W6BE+yrss/k+KL9cL4FcjppwDUDZ\nQS1R+XmY5fho1H+1/6F8/+W04s88yqNcBfzB4ewRqPqeA0ApLXibn4AJhIZ2r/Mzq5LKcJTFNJKe\nPS8/45yFjtFYiLrEgEzMNg8kwxe8nrOZhH300S9Yt05tuWetbdFw3A5Q/QxGTWr2xl5y+6Hq4/X3\nV1PbqgkDLkbZKDp6q4IdDsf/BDwGrCKA9fyFJfRE4xDKvsHh6L3Ad0AyM7Dtf1lYmFHnZ7auE1DH\nXXjh6TN9nTXIxGzzQwRf8HrOxmv+8UfbLvB6bcsAVHVMIcq6Kbc8PoaycRzr1k/avL8Y5bnvRmX9\nB1CLqqC21/+75Zi7gVmoydTjwP34cZw7uI8LqKpZLfsAyux5DGsPnP8BvxDLWraj7gqWoyZ9Aykp\naUtBgbnWxc5RuBcsGElVVcPfq0zMNj9E8AWv5+y8Zj1710V4p817v0cJcg9U1h2E6lLZBiW90ZZj\nJlse630oo4CHULZJAaqRWm/LuZdg7SMfClwLfInK9A1AJa14iGmspSWq4YFt82Mj8E/LmX/AwFv8\nRHT0ajiul4BqqN2n/CkqaklS0sZaIu0o3BERvtHhUTj/iOALXo+zi4BMJjMtWxZjbTlcSfv2p+jQ\noYqYmFOsWxeNmjgNQVXbPIG1cuYN1H+HauADVNOx6Vjl+VUgFtXorA/KytmL/cbeT6IuAN2Bv+HP\nVhKYRDBVhKHW49ree8SiMv1iYBVhFF74MqN676CkJJy0tGVAlkMMS2RiVWgUIviC1+Ho2c+ceSWO\nXrPtMR06ZAOB/PCDH2ZzT/SOM4GBc7jiigt4440refTRNahM3LZ2XpffdcCzNs/PcXjdgLqABKP6\n4kRZXg/HWk+Ti6rq6QQYaM10pvI6ccAh1KXAcQeqXagcftfVf+XzVbNqPv+QIWmoLQhXO8QQjtFo\nbvwXLDRbRPAFr8Pesy9g69YFREf3tpuwTUhYaXPMv7AX8uXAXVRUXEpqan9++WW+peVwFUpiQdkx\noGrsq7AX1hhqL3tqgbXR2WzUHMCtKBtHb5u8kAC+ZzjzuAhqvPo4y1nuxtp4YR/wHZEcaD+Jrz+c\nbPf5rXc09s3aYmN3kJw8HkE4V0TwBa/D3rNfR1bWk2RlWSds580bxKZNucBnqLr2AOBDy/GjUR0r\n56JWn75ETs5L2Fe5t8Bq59S1+2suKovXNwjMBfoC/0JZOlGoOv3lqHmAUYCBEN5iCjsJwbbxMDxt\n+WlEratNAl7hCoYOfYCvLRuB22Jt1uaPyTSXdu3i6NbtVJ2rZQXhbBDBF7wG3aZRu0Ppq17bYJt9\n792rcdllb1NW9kfURGknVL2LrdeeidqnNQJrWwMsP09TO6PvgrU12V7L43hUnX4R9nbPMlRWfxKV\nvz+PH/uJJ5o+VDAX1SvT9uzdUReAjqgan9dYQlBQFR9+OK7O70GqZwR3IYIveA22Vg5otG37EuXl\nJzl9Wm8ZUMC+fXuprn4Re4G3ldeLUNOhQ1HevGPVTjFqktb2Ob2RgV4FvxtlEd2Ftbe8fv5y1F3E\nt0B/gunNFPbQBWtdTrHD2fcAJ4C5PI1a2KURGvqi6744QXASEXzBa3AsvywpiaK6uhxYCORjMBRS\nXd0fewFuR+3e7yFY+9G/i/1WIUUoF/0pVO69H2XLrEb5+WGoCdqXUP89CrHtUaNkPRQD2VxPF66h\niDhUBX6X0UVuAAAgAElEQVRLyxHxWHtiHgR+IJhv+NYS0xLgd/r0EWtGOP9IawXB45w4YSYhYSUH\nDuzFdql/dfUhVOVLS+AhNK071kVSYF3s9CyqNn4Z1lYJuhV0G9bSS/189wDXAPcB/ijhH2455/2o\n7cCfQOXl01F2zyrURSKHAJ5kEg9yDUX0Rjn8D6LuJf6Gala8C7WI6oer/8qivbsIC/sSVc4ZCEzn\nxIm62yQIgjtxe4b/zTffMGfOHDRNY9y4cUyaNMndQwpegG3ZZExMHgZDJdnZHepsjfDQQ6kWK8ex\nX/x0rJt+L0fVzt+OypL1TUNOowTeH+XNL0CJ/SFUZh6GyvR160evrNmCEnRQUv0qtdsi2/aoAX/2\ncieP0Qkl246LqK6wRP078D3h7G8/hc/evIvw8DAGDowmJcW6w5T0rRE8gVsFv7q6mtmzZ7N48WKi\no6O59dZbufHGG+neXbKbpo6jH6+EfDTp6Rrl5e/QokXrmjr7I0d0K0fvF78ENXG6DjWRql8AilCZ\nfABqQrYUJdIXUbsscwLwIsryse1cqS+gMqIyeX3B1KWoUk1be2h/zWMD+dxOEp1Qa2mzqb2Iapcl\nor/zHPA05GrMmbOURYuM0rdG8ArcKvi//fYbRqORjh07AjBs2DDS0tJE8JsBGRn+WCtfirGVx82b\n8ygq6g74k54eQIcOv6AmQnWhPWY51rZ0ch5KmF8GrkZZO9NRG4E4ZubBKHHvbvndtsGZXrnTEmuZ\nZXeUXD+OtavNVmASBhYxhme4kZyada/hqBoix0VUR4BlrETdbahY9JWxUnkjeANuFfzc3Fw6dOhQ\n87h9+/Zs377dnUMKHka3cvbu3U99PeSLiqqwzchPnnyRsLBXMJs7oCpkOlN7pWs0ag9Z2wqd5aiL\nQ0vs5fc3lGUzHXVn8aHl+TxUM7OHgR9Qwr4AlZf/EVv7BvJoxTKmMZN5DiPehSoYnYOq9P8dSOav\nQDLWuxn1WcW6EbwJtwq+pmkNH+RAVFSIGyJxPRJn3Uyd+rnFyvkMW8E2GNqiaUtQAh2DdW/YYIqK\nqhk6tA2pqX6orQIN1M6hg1Btix07YZajBFvvn3MUlbFnoCyhLOy3F1mG6pXzrOW5EahGadZiSj/2\n8Sce4BKUfeM4Iqj7h2LgZ+C+las5tKyYgwdX07GjCU2rICtrNV27lrBgwUgiIs7/34ov/H36Qozg\nO3E6g1sFPyYmhqysrJrHubm5REdHn/E9vtDlLyrKN7oReiLOffuCUNKod3pUQqtpLVFTnX1Q2fda\nrFn+cNLSnqBt21YUFenyOgwl4hEosR9qeY/tRWAr1lYJlUCZ5fl01DKnO1HzAbaSHYK6IDjePagW\nyi1ZwZ9ZSRRK7B0bJ+9EXbIOA/P4KzCPqsX1t2uuqjr/f9O+8PfpCzGCb8XpDG4V/EsuuYQjR45w\n7NgxoqKiWLNmDa+99po7hxRsUOWOq5zaOMRVqD4wBajVrh+ibJTTqHLIO1HSeT3wH2xFt7z8QgwG\n6ySpyqFjUcuWdGtoKMoaCkNV1kSiBL81+mbg1sVYRcBbqAzfceGVfZ8cP78f8av+HxN5kWiUQXQZ\nSuyHYu/qlwDbCWAZe1AXDqSDpeAzuFXw/f39mTVrFvfddx+apnHrrbfKhO15xFrueP42qU5OHszW\nrQvIyrLtJvMSyh/XbZxAVAXMIpS9UwSUU1b2V1Stew+sE7ftUNXtF6JEvhR1Z6B78Ccs4ziutu1v\nGddoOacR5d/HoCqBVJ+cmBgThTmnmcrrNVO8rbGK/TqsG5OUAIb7JnHqxLWQ0s0ynvj0gu/g9jr8\nAQMGMGDAAHcPI9TBwYPBOL9xyJlxdpvB8PAwoqN7k5VlK8DtgV9RkqnbOONQojsCdVF4EXVR6Im6\nIPwN6wVjBipTvxh157Ac1YuyBEgE3qb2att1KMG3nW69HVXW+RtgwJ9D9Mx5mb5Qk9lXAL9YzqqL\n/fdArrErr//8ExVVgRQUmJESS8EXkdYKTZiuXYvZurXhjUOcwXGbQb2WPiOjNSbTXiIiutC9eyXJ\nyYOJicnDXoBboeri11HbT9d/jwIWo5oRnLa8pxRVNtkVtSh8JKoLpm255nLUBWW25Ryhlvd84zBW\nBaq0cwYQTkve5EFeJghVgW9bxT8Pa9u0X4G+H3zMjcNGEGbZSUpKLAVfRQS/CbNgQTxlZWfORPXM\nvS7hts3gHfvc/PBDMWbzg+gymZX1Pjt2BLFmzVpURfoLqJbCe1GLnlRFTm0/HcvvIVgbmC1Bib6+\n4YgJlYNrqDsAx7qZwyg//3eU7/8ZyhKy9sDx89tD9+4XcOD3KQzj31yE2tMqHzUlbHvGCNQ9QC6w\ns/f/MX2YbR2/IPguIvhNmIiIhjNRxxWxWVnL2bFjZK1NRxy3GSwpsb0AFKJE/lkqKx27wFdYfgaj\nJmuXo7L3LagFSgtQ3vpfLOcyoKpt9K0D9bJJtayp9sYkO1Fi74eaatVQ62BNGAwLMBgKCAgopLz8\nSTJ/f5XH+DctsC/UdOyGvxdY3fmv9L7iYj4Su0ZoQojgN3McM3clzKvsNh3ZsuUFIiO70bLlLMrK\nugIFVFaWohqSrUMJdBeH8/RC9YzvgzJJ/FENyu5CyeoW1MpWfd1qqOW9GsqnL0RV46iWxMHBwbRu\nncHJkyGcPDkLuBJ1FzAFa0Y/E1sZ17TOaFoorQK2MLG8KxEUcrHlXbaRhqIWUYUDu6KiefS7//FE\neIQLvl1B8C5E8Jsp1s1Gcqg94Wm/yjUnpx05OX7AH1AZ9RSsbvdc6l4odQRrqeQIm2MvRpVq9gJS\nULtP9QdmWc5/EjVluhblxa8FWtO2bQGXXRZJaupk9L48+lixsVmUlMTY1PAb0Dcab8F7TDr1Fn1Q\nS7JKLKPbRhqGarV22YrV3DZgoAu+XUHwTkTwmxG2lTbHj+8kK+shlOwtIyTkFOXlBygr80ctWtJ3\nnApFudlTsIq33mCgN9YLg75QKhrlpV+OfR6t7yk7EiXYek2+vvq1o+U1nWLgH+hZe1aWRnb2U8BS\nlGe/gKCgIMLDs4mIMFJdfYCiIpvaevZzEy25iHIuR80QBAKnLL+/iJrizQUyW7Zk8jdb6Ny1G4LQ\nlBHBb0bY+/WjgPctrxRQXNwG5YPbNv3VO0teQG3bR29yZrtQqhglpzEoJ9w2j24DVGP1823PV4i6\nSNger0u09ThNuxp1UVCWTXi4ucZ6ggJiY+cSHd2bnN8/5s8nV9ADlc3bVuC8iqrSH44ylHq/9TZT\n7ri7MV+rIPgMIvhNGMeVthkZAdgLbQFK0O+zPLbfzi8oKJLQ0AxycsB+Zep2AgK+oby8M0pC26EE\nXpU8qvLKu7GuUf0NdRFohVoEFYSSXF2GQ1H5ti7HJSg7Zz72F4GdqBYIYfj5taekRP8cAOFEtA1j\n4P4JtDhZXNPw7ANq32f8AmwA/mgptxSE5oIIfhPmgQdSSElR1S7p6RrR0S9gK6ABAcFUVtq2Frbv\nHBMenkV09CXk5NyAEu8yoAXV1X+mvPxfqIlafU3qfJTYg8r030bZNDstj5+ynONFVEbvKO6foTJ6\n2wtBEQbDU5bM/iQwGVXeeSctWhykqKhnTbxBPMWf9syhB8qmsZ3yrVXT87fnmPlIoku+Y0HwJUTw\nmzCbNtlPvppMkSi/3AAcoqoqFNiBVWSHoiZXewO7aN26DXv2bANyUPZNN5QL/jFK7HeiRPtt7Gvs\ny7BfHDUHqxVkQElxLPbibsDffzdVVXrXSwNt215ARUUwpaW23n4psbFzadu2M3v2DMOfF3iAp4lE\nTfmWoNbTrkXdY4xCFYgagX1Az7feZqRYOEIzRQTfx6ivxUFdzzvWo1RXF6AmX5cBT6BpytYJCHgG\n6EBl5WGUtbINuICMjN/RtBmo0kt9kdW/UBuRLMde1Gdh3SzcfkMSgyHC0iq72CaeocTEvMjp07EY\nDCauvroNYCQ19YGacw4atJStW49SWmr9DLGxOaSnTyMh4VP279nCgzxNR2qvvT2NMpbyLaMa3nqb\nv4rQC80cEXwfo74WB5s2VWI2twRuID09FFjK1Ve3JDX1JZS1coyIiALy8x0nTcMJDu6C2TwRtcI1\nEHgMNUmql17G2hyvi7njxGtfVG/6LNRCKqtI33ijgV275pKV1cVyvjhiY/ewceM9hIeH1bSgLSgw\n06KF/cpgs7mQMWPmUlDQifDwTFauHMnRjAx6bnyEKyiqKcB0XHt7CDUzkN82lAlfbpIKHEFABN/n\naKjFgV4yefhwW7p00YBpNa/17fsOO3a8YKmpt9opRUU5qOy8LepPwlY+9SZluoAXYW2LYOuOH0Jd\nWKpQ3S5fom3bKAYNakFy8jAAkpI2cvhwT4uYj6/VfK2uHjXh4WGkp0+refzh669wdO7zRKNaIDiu\nHNCAHwEzEP3W20yXrF4QahDB9zEcWxyoCpnaJZNGYxHHjkXYvZafH0OfPhXk5LRCTZqGAK2orn4I\nlQ8/i6qksfXWT6ImVZfTqlUZoaEZ5OYuQS2YeslyjmKUp1+NukNQq2mvu+49OwFvTMOxXdu28Wn8\nYDprGiGWT51nGd22Z/1m4GhQax7/+nvJ6gXBARF8H0H36A8caENs7BxLk7MqysurSE21XgDCwvYw\ncGABTz55Bbfeuhpb8TYai9i06TQwFaugv4+qfAFlyVyCtY/8LtS+sGHAnUREzKWg4EJUnxtFQMBz\nVFY+jaqLWYtqoaA2B8/OjnTJZ1/98VL2JD7MGyhhn24T/VxUfVBHVGfLuLfe5nHJ6gWhTkTwfQTH\nJmeXXvoe0IKjR1sTGzuXdu3i6NbtFMnJdxIeHkZCwkoyMyej574xMb9SXh5JUVE7lH0TjxLyAlQd\n/nKs1TS6NdQH+CcBAZFERBwnK2sq6uKgoQt8dXVny/kqsDY8U6tnjcZKwPle+nXxzovPU/LmK/S0\njOLY2bIT6rK0u20od4lXLwhnRATfwzgrho7e/Y8/+mE2Wy8AV11l3c3KZDKzaVMlqi7+LgCOH99h\n6UNjK+h3Uv8kbAVwjKFDI/jww7sZMiSN48f15z9E7Vg1nerqcMv5PrR7f1jYaZKTbwZqTzQ7s/NW\nWspnbEmYQBuse8sOpfZWJzuB61es5o/SA0cQGkQE38M4K4a1vXt9az8A+92sZszYiNlchbJWQoAi\nqqvD7I4PCqrA3382JSX+qBW2O7H37gNRxY7v2Yy/FvssXu+pY8DP7xjV1db4Bg4MqLlwOV6sGtp5\n6z8LF3D0bzO4Cvu2CMtRl6VZqLqhA35+jFi3kd59Lz/j+QRBUIjgexhnxVDV1VtLFsvL29h590Zj\nUc3dwvr1oNab2u4rOwfb3HjIEPjii3KsneGvB55BbczdApVPG2p8+OTkwWza9CVms2MBJIDGLbdE\n1Cqp1HG8WNW389aWDRv45s7RdEXtcWVfza9G247aB6vlW28zQ7x6QTgr3Cb48+fP5z//+Q/t2rUD\nIDExUfa2rQNnxdCxZLGumvWkJFuf374vjiqvXEZY2GkGDgwgOXkQ69dX2RwTDlyFqrixdrLU4wkP\nD2PgQH9SUmwXQe0gOrqamJh8gHptKceLVV07b+kWzlUood9B7f2xvgcKWrfmwY1SgSMI54JbM/yJ\nEycyceJEdw7h8zgjhnWhXwD0rP6OO7ZZetvrXWRKsG5Q0gb4BX//Mq65xkhy8gjCw8MID8+yW8Wq\nO+V610nHeGrHOr5mgjgl5X7qs6XOtAfsrm3bWDlyCK0qKojCauH0x9pBPwzVmu1SaYsgCI3CrYKv\nVmoKZ+JMYujMhO6jj37BunVKbK37wd4DDMVgeBlNe9Hy2giqql4lNXUKLVooQV65cpRlFWsHNO0A\nXbv2IC5udZ2LovRY580bVBNTUtIGkpMHn7VHD8q++e6uMcRpGq1QbdG+xv5+oyewBwiZ+wp/u39S\ng+cUBOHMuFXwP/74Y1JSUrj44ot54oknCAkJcedwTQZd6Otql+B4cfjxRz9sxTYg4DQXX/yZpea+\nu4PnrpqS6YLctavRbhWrM9Q1yWw0ak7ZUjqrP17KvsSHa/bK0uvpY7G3cHYAwdNncKeIvSC4hEYJ\n/sSJE8nPz6/1fGJiInfffTcPP/wwBoOB119/nblz5zJnzpwGzxkV5RsXBXfFeeKEmZtu+pjMTH17\nQGs1TFZWeK1xDYYT2MpkSEgxv/zyIACjRy+289z1n3FxpWcV/4kTZh56KJWDB4PZv78a2wtMVlY4\n69Zdz5Qpyzl4MJiuXUtYsGAkERG1z3/49995+7rr0PLyuBb7GYYY1KaFy1BtEQ4FBjLhhx+49Mor\nnY7zfNDc/z5diS/ECL4TpzM0SvA/+OADp467/fbbmTx5slPH5uUVNyak84Le7MsdJCSsIjPTdutA\na7uE2NiCWuNefXUbUlP1LpXFXHGFH6NHL7HYQBUMHfoemZlhnDixj4gII927L2X27EHk5RU7vQYg\nIWGVzWSw/d61sbEFVFX5M3/+8Jrjq6pq/ztu2bCBNXeOph2qyfJOVF2QXsV/AHVZy42MYsSaL7nN\nMinrTX8P7vx3dyW+EKcvxAi+FaczuM3SycvLIyoqCoAvv/ySuLg4dw3VpFB2i307ML1dgu0Eqi7W\nmZnRxMbutWm10NbOchk1ailpabcAt9Qay9k1APYe/TDCwl6hS5cLnZpkLjSZ+PD2MZz67Rcisd9A\nUe+8vxXVxrjrW2/zkEzKCoLbcJvgv/zyy+zevRs/Pz86duzI888/766hmhQxMXnArVhbIvzGpk33\n1Mq8HVst6CtthwxJw9kJVGcnW+1LR0MZOLA9ixbd2OBnKTSZePe6fgSeyKcDEIf9fUsUkI5aQnaD\nbDcoCG7HbYKfnJzsrlM3aQyGSlS/GmXRXH55O6daLehi7Wxd/9kce7alo4UmEx/eOoLAHdvpgloC\n1prabYx/B4au38TAmwf4xG2zIPg6stLWy8jO7oCavtQff1bncfWJta04x8WVMnt2/eLsrJCfqXTU\nkS0bNrDqztG0R7VAsF3nexfWNsbfA/1XrJa2CIJwHhHB9zIam3XbinNDE05nI+QNUWgy8cn/3U7Z\nT/8jGrVm19a+CQcWoBZRHbj8Sh5Y/gmh4REuGVsQBOcQwT+POFMV8+STV7J1q76l31FmzhxV57lc\nKdaNpdBk4tUr+hB56iRdgQyUjWNr32QBtGzFdau/kKxeEDyECP55xJmqmLlzfyYr60nAQGmpxpw5\nS1m0yOiJcJ1CX0TVDvsKnGewt286zn1FFlAJgodp8oJfV1ataZzzhhyNwZmqmHNpU+AJ0lI+Iz1h\nAj1QmyN2wN7CuQDV1fJXoK9U4AiCV9DkBb+urBo46w05XIEz/vzZVNl4gkKTifWJD3MkdY1da4SZ\n2Fs4+1FCP12EXhC8hiYv+PVnzOc/i3amKuZcu2eeD45mZLDwun6EVVfVqqmPRO2EG4US+/C/PSdZ\nvSB4GU1e8OvOmM+u2ZercGai1ZsmY3UKTSa+eHgSpWnrCUNtObgT1XxZb41QBJQBmZf2JfG/n0kF\njiB4IU1e8OvPmL0zi/Y2jmZk8J8Bf2RuRTnLgenozZZVa4RoVEYfcNUfeOCj/4jQC4IX0+QFv76M\n2duyaG9k17ZtpA4dRFfq3ua8N/CjwY/79hwQoRcEH6DJC75w9hzNyODzUX/C73guc4FXULZNMbW3\nHBz6xUYRe0HwEUTwBTv0rL43qtfNEdTq2GUooX8JaAvkRbfn9tVfyN6yguBDiOALgEXoRw7hgooK\nLgGGoerrXwKmAGtRk7SFLVpy7efrZbWsIPggIvjNHL0C50Taeru6erXHFrRHZfeHUZ0tbxOhFwSf\nRQTfDTi7k5SnSUv5jG8SJhAFGKlrjy3YB1RFRnHXmi/FvhEEH0cEvwHqEu+GthNzdicpT1JoMvFz\nwgQ6A0+gsnjbCdm9wGZUVi/2jSA0DUTwG6Au8f7sswlnfI8398PRWyPkfLWei4BAVKTxKBvnJJCN\n2nLwZulXLwhNCj9PB+DtnIt4G42FqDwZvKkfztGMDN699CJOpK7huYoKgoBjqEjDgDtRi6hCbhzC\ntL2H+OOAgZ4MVxAEFyMZfgOcSzMzb+yHU2gy8emga7m2vAwT1qx+CfA00BXYHxjI7d9tFa9eEJoo\njRL8devWMX/+fDIyMlixYgV9+vSpee2dd97hk08+wd/fn6eeeor+/fs3OlhPcC7i7S39cMwnTvDf\ne+6hcPO3VBYXM1vTMAAfY83qpwFzAgOpvOkW7ntjviyiEoQmTKMEPy4ujvnz5/P000/bPZ+RkUFq\naipr164lJyeHiRMnsn79egwGQz1n8l68RbzPlqMZGSy89goiNI0eqEVU24FLUTX2r6I6XO5v2Yp7\nf9sjQi8IzYBGCX63burWX9M0u+fT0tKIj48nICCATp06YTQa+e2337jssssaM5zgJLp9E6VpdrtQ\nPY0S/FDADJyIiua2z9eL2AtCM8EtHn5ubi59+/atedy+fXtyc3PdMZTgQKHJxL8HX0e306VUYF9b\n3xVYDByL7ci9GzeL0AtCM6NBwZ84cSL5+fm1nk9MTGTw4MF1vscx4wectnMaqnH3FrwtTvOJE6Q8\n8ACZa9Yws6LCzqvXM/x9QI9Ro3j4/fcJi/Ausfe277M+JE7X4Qsxgu/E6QwNCv4HH3xw1ieNiYkh\nOzu75nFOTg7R0dFOvTcvr/isxzvfREWFeE2cu7Zt48sx8XQ9Xcpx7FfMDgNeADqixF7fW7aiyru+\nZ2/6Ps+ExOk6fCFG8K04ncFldfi2Wf3gwYNZu3Yt5eXlHD16lCNHjnDppZe6aijBhi/HxDP7dCn3\no1bM7sW6AiAU8IvtyIC9h5h+vEi2HBSEZk6jPPyvvvqK2bNnU1BQwOTJk+nZsyfvvvsuPXr0YOjQ\noQwbNoyAgACeeeYZn6zQ8WaOZmSQOm443U6X2vn03VFtEsqBnE6duCPtO/HqBUEAwKDVZbh7EF+5\nffJUnIUmE9/OeIyDa1fzXEUFy1BdLXWffhYQGhZG62uu488fLaGiKtAjcZ4NvnTbLHG6Bl+IEXwr\nTmeQlbY+gi702qYNtDSb6YZ9D5xS4ECrIG5eta6m/01YhG/8sQqCcH4QwfcRvp3xGPemfGqXydv2\nwJkT25G/pO/2ZIiCIHg5IvhejJ7Vhx4+hHbogJ1X3xO1G1V7Pz+yYzowdOUazwUqCIJPIILvxdhm\n9Y419dlhYcQMHMz1ya/JpKwgCE4hgu9l7Nq2jS9G/wljWRm5wALgblRN/SthYXTv0o1CYxfGiNAL\ngnCWiOB7GV+OiefFsrKaTH4ZkIry6SMHDub6RYs9GZ4gCD6MCL6X0a3stJ1XHwKYgoJYPGQo1ye/\n5sHIBEHwdUTwPYztxGyh0cjuwBZo5dYMvxioHjKU4ZLZC4LQSETwPYxduWX6z/x94CCe+vF7jGVl\nHDcYaHX9QMZIZi8IggsQwfcwoYcP2Vk4nQsLuftonidDEgShiSKbmJ9HCk0mPk+4l2+H3MDnCfdQ\nWGCi0Gi02e4cCo1dPBihIAhNGcnwzyOO9s1iDFyf/DqLMVg8/C4yMSsIgtsQwT+PONo3oYcPERoe\nIROygiCcF8TSOY+IfSMIgieRDN8NOJZaXp/8OqHhEWLfCILgUUTw3UBdXv3wRYvFvhEEwaOIpeMG\n6vLqBUEQPI0IvhsQr14QBG9ELB03IF69IAjeSKMEf926dcyfP5+MjAxWrFhBnz59ADh27Bjx8fF0\n69YNgMsuu4xnn3220cH6CuLVC4LgjTRK8OPi4pg/fz5PP/10rdcuuOACVq5c2ZjTC4IgCC6kUYKv\nZ/CapjVwpCAIguBp3DZpm5mZydixYxk/fjw//fSTu4YRBEEQnKTBDH/ixInk5+fXej4xMZHBgwfX\n+Z7o6Gi+/vprQkND2blzJw8//DBr1qyhTZs2DQYUFRXiRNjnD/OJE6Q+9BDBBw9S3LUr8QsWAN4X\nZ31InK5F4nQdvhAj+E6cztCg4H/wwQdnfdLAwEBCQ0MB6NOnD507d+bQoUM1k7pnIi+v+KzHcyef\nJ0yyLqLaupXFZZVM/OwTr4uzLqKiQiROFyJxug5fiBF8K05ncJmlY+vjm0wmqqurATh69ChHjhyh\nc+fOrhrqvCKLqARBaCo0atL2q6++Yvbs2RQUFDB58mR69uzJu+++y08//cTf//53AgIC8PPz4/nn\nn6dt27auivm8Umg0oqX/XLPloCyiEgTBV2mU4N90003cdNNNtZ4fMmQIQ4YMacypvQZZRCUIQlNB\nVto2gCyiEgShqSC9dARBEJoJzVLw69pbVhAEoanTLC2d+vrVC4IgNGWaZYYvpZaCIDRHmqXgS796\nQRCaI03e0qlrf1kptRQEoTnS5AW/Pr9ePHtBEJobTd7SEb9eEARB0eQFX/x6QRAERZO3dMSvFwRB\nUDR5wZfWCIIgCIomb+kIgiAIChF8QRCEZoIIviAIQjNBBF8QBKGZIIIvCILQTBDBFwRBaCY0SvCT\nk5MZOnQoo0aNYtq0aZSUlNS89s477zBkyBCGDh3Kd9991+hABUEQhMbRKMHv378/a9asISUlBaPR\nyDvvvAPA/v37SU1NZe3atSxatIjnnnsOTdMaOJsgCILgThol+Ndeey1+fuoUffv2JScnB4ANGzYQ\nHx9PQEAAnTp1wmg08ttvvzU+WkEQBOGccZmHv2LFCgYOHAhAbm4uHTp0qHmtffv25ObmumooQRAE\n4RxosLXCxIkTyc/Pr/V8YmIigwcPBmDBggUEBgYyfPhwgDrtG4PBUOs5QRAE4fzRoOB/8MEHZ3x9\n5cqVbNq0iSVLltQ8FxMTQ3Z2ds3jnJwcoqOjnQooKirEqeM8jcTpWiRO1+ILcfpCjOA7cTpDoyyd\nb775hnfffZcFCxbQokWLmucHDx7M2rVrKS8v5+jRoxw5coRLL7200cEKgiAI545Ba0T5zJAhQ6io\nqIMzjrUAAATvSURBVCAsLAyAyy67jGeffRZQZZkrVqwgICCAp556iv79+7skYEEQBOHcaJTgC4Ig\nCL6DrLQVBEFoJojgC4IgNBNE8AVBEJoJXiv47733Hj179sRsNns6lDp58803GTlyJKNHj+b+++8n\nLy/P0yHVyZn6HXkT69atY/jw4fTq1YudO3d6Ohw7vvnmG/70pz9xyy23sHDhQk+HUy8zZ87k2muv\nZcSIEZ4OpV5ycnKYMGEC8fHxjBgxwq6c25soLy/ntttuY/To0YwYMYL58+d7OqR6qa6uZsyYMUye\nPLnhgzUvJDs7W7vvvvu0QYMGaQUFBZ4Op05KSkpqfl+yZIn29NNPezCa+tm8ebNWVVWlaZqmvfzy\ny9orr7zi4YjqJiMjQzt48KA2fvx4bceOHZ4Op4aqqirtpptu0jIzM7Xy8nJt5MiR2v79+z0dVp1s\n3bpV27VrlzZ8+HBPh1Ivx48f13bt2qVpmvo/NGTIEK/9Pk+dOqVpmqZVVlZqt912m/brr796OKK6\n+eCDD7Tp06drDz74YIPHemWGP2fOHJKSkjwdxhlp06ZNze+lpaU1PYW8jfr6HXkb3bp1o0uXLl7X\nZO+3337DaDTSsWNHAgMDGTZsGGlpaZ4Oq0769etH27ZtPR3GGYmKiqJXr16A+j/UvXt3jh8/7uGo\n6iYoKAhQ2X5lZaWHo6mbnJwcNm3axG233ebU8Q2utD3fbNiwgQ4dOnDRRRd5OpQGef3110lJSSEk\nJMRrb01tWbFiBcOGDfN0GD5FXX2htm/f7sGImg6ZmZns2bPHaxdlVldXM3bsWI4cOcKf//xnr4xT\nT46Li4udOt4jgl9ff55HH32Ud955h/fff7/mOU9mfA31EUpMTCQxMZGFCxfy0UcfMW3aNA9EeXb9\njjzp7zoTp7fhbXccTYWTJ0/yyCOPMHPmTLu7ZW/Cz8+Pzz77jJKSEh566CH2799Pjx49PB1WDV9/\n/TWRkZH06tWLLVu2OPUejwh+ff159u3bx7Fjxxg1ahSappGbm8u4ceP473//S7t27c5zlA33EdIZ\nPnw4Dz74oMcE/1z6HXkCZ79PbyImJoasrKyax7m5uU73hRLqprKykkceeYRRo0Zx0003eTqcBgkO\nDuYPf/gD3377rVcJ/s8//8yGDRvYtGkTZWVlnDx5kqSkJJKTk+t9j1cZz3FxcWzevJm0tDQ2bNhA\n+/btWblypUfEviEOHz5c83taWhrdunXzYDT1U1+/I2/Gm7LqSy65hCNHjnDs2DHKy8tZs2YNN954\no6fDqhdv+u7qY+bMmfTo0YN77rnH06HUi8lkqrFJTp8+zQ8//OB1/8cfe+wxvv76a9LS0njttdf4\n4x//eEaxBy/08G0xGAxe+wf86quvcvDgQfz8/IiNjeW5557zdEh18sILL1BRUcF9990H2Pc78ia+\n+uorZs+eTUFBAZMnT6Znz568++67ng4Lf39/Zs2axX333Yemadx66610797d02HVyfTp09myZQtm\ns5kbbriBadOmMW7cOE+HZce2bdtYvXo1cXFxjB49GoPBQGJiIgMGDPB0aHbk5eXxxBNPUF1dTXV1\nNfHx8TX7ffgy0ktHEAShmeBVlo4gCILgPkTwBUEQmgki+IIgCM0EEXxBEIRmggi+IAhCM0EEXxAE\noZkggi8IgtBMEMEXBEFoJvw//5K32R/vBHAAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f5be3c99f50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Current loss: 9.48636\n" + ] + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs, outputs, c='b')\n", + "plt.scatter(inputs, model(inputs), c='r')\n", + "plt.show()\n", + "\n", + "print('Current loss: '),\n", + "print(loss(model(inputs), outputs).numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "sSDP-yeq_4jE" + }, + "source": [ + "### Define a training loop\n", + "\n", + "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "MBIACgdnA55X" + }, + "outputs": [], + "source": [ + "def train(model, inputs, outputs, learning_rate):\n", + " with tf.GradientTape() as t:\n", + " current_loss = loss(model(inputs), outputs)\n", + " dW, db = t.gradient(current_loss, [model.W, model.b])\n", + " model.W.assign_sub(learning_rate * dW)\n", + " model.b.assign_sub(learning_rate * db)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "RwWPaJryD2aN" + }, + "source": [ + "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 446 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 569, + "status": "ok", + "timestamp": 1527005915434, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "XdfkR223D9dW", + "outputId": "c43591ae-d5ac-4f2b-a8e7-bfce607e0919" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0: W=5.00 b=0.00, loss=9.48636\n", + "Epoch 1: W=4.58 b=0.42, loss=6.28101\n", + "Epoch 2: W=4.24 b=0.76, loss=4.29357\n", + "Epoch 3: W=3.98 b=1.02, loss=3.06128\n", + "Epoch 4: W=3.78 b=1.23, loss=2.29721\n", + "Epoch 5: W=3.61 b=1.39, loss=1.82345\n", + "Epoch 6: W=3.49 b=1.52, loss=1.52970\n", + "Epoch 7: W=3.38 b=1.62, loss=1.34756\n", + "Epoch 8: W=3.30 b=1.70, loss=1.23463\n", + "Epoch 9: W=3.24 b=1.76, loss=1.16460\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW0AAAEDCAYAAAD+/1UIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzt3Xl4VOXdPvD7zJZ9XwmELQkQIAELsiTsi6xiEBGXAiIW\nbV8WBY2K0tLa4lbsr283qxURtIoioAi8SpFNg6whi0FJKAoJBgLZt5k5c87vj5OZLIRkgEnOGXJ/\nritXJsmZyT0sN1+enPOMIMuyDCIicgs6tQMQEZHzWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu\nxODMQePGjYOvry90Oh0MBgM2b97c1rmIiKgZTpW2IAjYuHEjAgIC2joPERG1wKnlEVmWIUlSW2ch\nIqJWCM5cETl+/HgEBARAEATMmTMH9957b3tkIyKiJpxaHvnggw8QFhaG4uJiLFiwAD179sTgwYPb\nOhsRETXh1PJIWFgYACA4OBgTJ05EVlZWi8fL3t6AIADdugFvvglYrTeflIiIWl8eqampgSRJ8PHx\nQXV1NR5++GEsXrwYI0aMuPadCgtRvfoFeL2zDkJtLWxdu6NqRSrMs+8DDE4N9y4XFuaHoqIKVb73\ntTCTc7SYCdBmLmZyjlYzOaPVSfvy5ct44IEHkJKSgjlz5mDcuHEtFzYAREai6oWXUHwkA9WPPApd\n4QX4L/sVgpMGwWPTvwFRdCocERE15tQPIm9Ew3/FdBcK4P3ntfB89x0IVivEmFhUP/kMzCmzAL2+\nLb79VbT6LysztU6LmQBt5mIm52g1kzPa5YpIKaozKl9+DcWHT6Jm7gLof/wB/r98BEGjh8Fj28cA\nTyckInJKu17GLnWJRuXaP6P40AnUPDgP+jN58F+0AEFjhsO0fRvLm4ioFarsPSJ1647KP/0VxWnH\nUTvnAehPf4+AhfMQNG4ETDs/A/hiOkREzVJ1wyipR09U/OV1lHx9FLX3zIH+uxwEPPQAAieMgunz\nXSxvIqImNLHLny0mDhV/fxMlB4+g9u57YMjORMDcOQicNAamPV+wvImI6miitO1scb1Q8fo6lOz/\nBrUzZsJ4Mh0B99+DwKnjYdy7h+VNRNftL395DR999IHj4+XLl2DVqlWOj//61/+HDz/8txrRboim\nStvO1iceFf96B8V702CeNgPG48cQOGcmAu+cBOOBfSxvInJa//6JyM7OAKBsfldWVorc3FzH17Oz\nM5GQMECteNdNk6VtZ+vXH+Vvv4uSPQdhnjwVxiPfIPCeGQhImQpj2ldqxyMiN5CQMBBZWZkAgLNn\nz6Bnzxj4+PigsrISVqsVP/74A+Liequc0nnqXFN+ncSEASjf8AEMJ0/A+9UX4bH7c5hSpsIycjSq\nnloJcdhwtSMSkRN8Vj8Pj+3bXPqY5jtTULX699f8emhoKPR6Ay5duoisrEz075+I6uoyZGdnwsfH\nBzExsTCotL3GjdD0pN2UOPBnKH/vI5Ts2gPL2PEwHdyPoBmTEDD7LhiOHlY7HhFpVGJiIrKyMpCd\nrZT2gAEDkJWVgaws91oaAdxk0m5KHHQ7yjZtheHIYfi8sgam/Xth2r8X5vETUZ26EuJtg9SOSETN\nqFr9+xan4rbSr18isrIy8d//KssjHh4y/vnPf8HX1wfTpt3V7nluhltN2k2JQ4aibPMnKP1kFyzJ\nI+GxZzeCJo2F/8/vhSHzpNrxiEgjEhIGIC3tIPz9/SEIAgICAlBZWYHs7Cz075+gdrzr4talbWcd\nnoyyrTtQuuUzWIYlweOL/0PQhFHwn3c/9HU/gCCijismJhbl5WXo3z+x0ef8/Pzg7+9er33bLrv8\ntStZhvHAPvi8/AcYjx0BAJin3wWPXz+Hom69lRdn0Ait7jTGTM7RYi5mco5WMznjlpi0GxEEWEeP\nRemO3Sj9YAusPxsEj88+AYYMQeCEUfBc/xaEinK1UxIR3ZBbr7TtBAHWcRNQuutLlG7aCsycCUNO\nNvxSn0BIQm/4Ll8CQ/pxXqhDRG7l1i1tO0GAdex4YMsWFJ88hapnV0EKCYHXu+8gaNJYBI4fCc+3\n/wWhvEztpERErbr1S7sBKSIS1U88heKjmSj9YAvM02bAcOpb+D29HCGJveH7xGIYThzj9E1EmtWh\nSttBp4N13ASUv/2uMn2v/DWk0DB4vbcBQZPHIWjcCHiue5PTNxFpTscs7QakiEhUP/4kio9koHTT\nVpinzYD++1Pwe2aFMn0//j8wHD/K6ZuINKHDl7aDTgfr2PHK9J2eg8rnfgMpNBxe/96IoCnjETQ2\nmdM3kZsqLPwJ8+bNUTuGS7C0myFFRKJm2QoUHzmpTN/T74L+9HfK9J3QC77LfgXDsSOcvonciKCh\nazRuBku7Jfbpe91GXEk/hcrnV0MKj4DX++8iaOoEZfp+6w0IZaVqJyWiVoiiiD/8YTXmz78fy5Yt\ng9lsVjvSDbn1roi8BpddASVJMB7YB6+N62Ha9RkEUYTs5QXzXXejZu5DEAcPcfqqS61elcVMztFi\nLq1nWr3aA9u3u3afujvvFLF6dcsFXFj4E2bPnoF//GMd+vdPwJ/+9CI6dYrGfff93KVZbkbHvSKy\nrel0sI4Zh/K3NuDKye9Q+fxvIYVHwPOD9xA0bSKCxiTB861/cvom0piIiEjH5lAzZsxAZmaGyolu\njFtuzaoVcng4apY+gZrFy2A8uB+eG9fDY+d2+D37FHx/92uYZ8xEzbwF1zV9E93KVq82tzoVt5Wm\na9ru+leSk7Yr6HSwjh6Lin+9o0zfq34HKSISnpv+XTd9D4fnv16HUFqidlKiDquw8Cd8+202AGDH\njh1ITByocqIbw9J2MTk8HDVLHkfxN+ko3fwpau+6G/q8XPitTEVIYm/4LXkMhiOHeeYJUTvr3r0H\ndu36DPPn34+ysjKkpNyjdqQbwh9EtgOhqAiem/4Nz41vw3D2vwAAsU88DAsfxpVREyH16KlKruZo\n/QdZWqLFXMzkHK1mcgYn7XYgh4WhZvEylBw6gdKPt6M25W7oz+QBTz2FkKEDETR6OLxf/gMMWRmc\nwImoRfxBZHvS6WAdORrWkaNReeUKQr/eA/Omj2A6sA8+a1+Gz9qXYYvuCvOUabBMvRPWIcMAN3qV\naCJqe2wElcghIcDChSifcS+EygoYv/wPPHZ+BtPuz+H9xj/g/cY/IAUHwzxpKixTpsMyeizg5aV2\nbCJSGUtbA2RfP1hmzIRlxkzAYoHx64NKgf/fDni9/y683n8Xsrc3LGMnwDx1OiwTJ0EODFI7NhGp\ngKWtNSYTrGPHKy/c8PJaGE4cg8euHTDt3A6PHZ/CY8enkA0GWJNGKgU+ZRqkTlFqpyaidsLS1jKd\nDuLgIRAHD0HV86uhP/09PHZ9BtPO7TAd2AvTgb3AMytg/dkgmKdMh2XqnbDF9VI7NRG1IZ494i4E\nAbbefVD9+JMo/WI/rqTnoOLFV2EZOQaGjJPw/cNvEZw8GEFJg+Dz+9XKHuCSpHZqItVVVlZi69bN\nbfb406dPQGVlJQDgypXLGDnydmRlZTT4+kSUl7vuxcSdLm1JkjBz5kw89thjLvvmdOOkzl1Qu/BR\nlH38Ka7knEH5X/8J89Q7oS/Ih/f/voagKeMRPDAevqlPwLjvS8BiUTsykSoqKsqxdetHzX5NcsFg\n07dvArKzMwEA2dmZ6NWrD7KylI/PnfsRgYFB8Pf3v+nvY+d0aW/YsAExMTEu+8bkOnJQMMz33o/y\n9e/h8qmzKHvnfdTOeQCCuRZe699C4L0pCOkbA79fPgLT9m1A3VRA1BG8/vpfceFCAR5++EH8/e//\ni/T045g3bx5++9vnMX/+fVe9QML777+Lt99+EwBQUJCPFSuW4pFH5mHx4kU4d+7Hqx4/ISHRUdpZ\nWZmYM+dBfPttfYknJCS69Pk4taZdWFiI/fv347HHHsPbb7/t0gDkYt7esEyZBsuUaYAowvhNGky7\nPoPHzs/g+fGH8Pz4Q8geHrCMGQfLlOkw3zEFcmio2qmpAwke1L/Zzxcfz3bJ8U398pdL8MMP/8W6\nde8BANLTjyMrKwsbNnyIyMhIFBb+dM0XSHjllTVITV2Jzp27ICcnG2vXvoQ///kfjY7p3z8R69e/\nBQA4depbPPLIY/joo38DUEo8IWGAUzmd5VRpr1mzBqmpqaio0NZln9QKgwHWEaNgHTEKVb9/GYas\nDOUslF074PH5Lnh8vgu+Oh2sQ4fDMnU6zFOmA2HN/wUhupUkJiYiMjKyxWNqamqQnZ2BVauehn23\nD1EUrzqub99+yM39HrW1tbDZbPD09ERUVGcUFOQjOzsD99/v2j27Wy3tffv2ITQ0FPHx8Th8+LDT\nD+zsdfTtqcNnGj9SeVv7CpCbC3zyCYStW2E6lAbToa/hu+pZICEBYWPGAGPGAKNGARqZwrX4ewdo\nM5fmMzWzxAAAYde68/Ue34TFUg69XufIEBjoDS8vL8fHklQNQajPaDQCOp0JwcHeCAgIwPbtn7by\nHfzQvXs37N//OQYMSEBYmB+GDBmMrKxjKC8vw6Br/E/hRrVa2idOnMCXX36J/fv3w2w2o6qqCqmp\nqXjllVdavJ8WN2NhpgYCI4H5jwLzH4Vw8SI8Pt8Jj53bYUr7CsjKAv7yFwCAGN8X1uHJsCSPhHVY\nMuQwZ/+quI4Wf+8AbeZipqvV1sqoqKh0ZCgtrQZQ31GSZMLly1dw5kwBPD09sXv3HgwbloSaGhkR\nEZ3w4YdbMXbsBABAXl4uYmPjrvoeffr0w7p1b2PhwkdRVFSBbt164YUXViE+vp/Tz93Zf2xbLe3l\ny5dj+fLlAIAjR45g3bp1rRY2uRc5IgK18xagdt4ChPmbUPrFPhi/Pghj2tcwHjsMw6kceK1TfjAj\n9u4D6/BkWJNHwjJ8BOTwcJXTE7XM3z8ACQkDMH/+fRg6NAnDhyc3+rrBYMCCBY9g0aL5iIrqjG7d\nuju+9utfv4A//vElvPPOOthsIsaPv6PZ0k5IGIDNmzehXz/llXF69+6DoqIizJgx0+XP57q2ZrWX\n9uuvv97qsfzXvnVukcligSH9BEyHvlKK/OhhCNXVji+Lcb1gHT4C1uQRsCaNgBTR8jqhSzJphBZz\nMZNztJrJGdxPW0VumclqheHkCRgPfQ3T1wdhOHIYuqr6UwjFmFhYk0Y43lxxib0Wf50AbeZiJudo\nNZMzeBk7XR+jEeLtQyHePhQ1S5crJZ55UllKSTsI4+Fv4LVxPbw2rgcA2Lr3UNbD65ZUpM5d1M1P\n5OZY2nRzjEaIg26HOOh21Cx5HBBFGLIylBI/9BWMh9Lg9d4GeL23AQBg69odluQR9SUe3VXlJ0Dk\nXlja5FoGA8TbBkG8bRBq/mcpYLPB8G0WjF9/VV/iddvNAoAtuiusSSNgsS+ndO3mvi+TTdQOWNrU\ntvR6iIkDISYORM0vFwM2G/Q538KUdtAxjXtu+jc8NylXkNk6d3Gsh1uSRkDq3kPlJ0CkLSxtal96\nPWwJiahJSETNo/8DSBL0p3Ial/hHH8Dzow8AALZOUcCY0fDqkwAxcQDEhETI/gEqPwki9bC0SV06\nHWz9+qOmX3/U/OKXSol//x2MaV/BlKYsqeD99+GL9x13EXv0VKb3hAF1RT5Aefk2og6ApU3aotPB\nFt8Xtvi+qF24CJBlhJUWonx/GgyZGcpb1kl4frIF+GSL4262LtH1JZ44AGLiwDY5Z5zcT2VlJXbv\n/j/MnHlPm32PNWt+i+TkkRg9elybfQ87ljZpmyAAvXrBHNQJ5pRZyudkGbr8844CN2RmwJhxEh67\nPoPHrs8cd7WFR9SXeMJAiIkDIHWJ5g86Oxj7ftpNS1uSJOh07vc6MCxtcj+CACm6KyzRXWGZdqfj\n07qLhTBknmwwkWfA4z9fwOM/XziOkYKCHAVuf7N17wm44V9edzVokE+znz9+vMolxzfVcD9tvV4P\nLy9vREVF4ttvc/Dqq39Gaurj2LBhEwBlL+3a2hosWPALFBTk47XXXkFZWSk8PT2Rmvocunbtds3v\nc/ToYXz44fsoKSnG4sVPIClphFP5rhdLm24ZUkQkLBMnwzJxsuNzwpUrMGTVl7gh82T962va7+fr\nBzEh0bE+LiYOhC02DjDwr8etoOF+2unpx5Ga+gTWrn0VRqPfTe+l3VBh4U/429/eRH7+eSxd+hg2\nbdoGo9Ho8ufDP5V0S5NDQmAdMw7WMfVrjUJ5GQzZWfVTeVYGjIcPwXTo6/r7eXlB7NvfsT4uJg6A\n2DseMJnUeBq3FGcn5Bs9vjV9+/ZDVFRUi5exO7uXdkPjxk0EAHTpEo2oqM748ccfmt1c6maxtKnD\nkf0DHOeCO1RVwZCT3WAiz4AhIx3G40fr72c0QozvpxR4/0Rg2CAIIZ2VnQ65Tu42PD09Hbf1ej1s\ntvrXibRYzAAAWZbg5+fveLUbZzSd2K81wd8sljYRAPj4OPZUcTCbYfgup9FZK4Zvs2HMPOk4JBSA\n5B8AW2wsbDFxsMX1glj33tajJ+Dh0f7PhRrx9vZGdd3OlE33xwsKCkZpaQnKy8vh6emJtLSvMGxY\nEry9fdCpUxT27v1Pq3tp2+3d+x9MnjwNFy4U4MKFghbXv28GS5voWjw8IA64DeKA2+o/Z7VCn3sa\nhqwM+F/4EeaMbOjP5MKQlQnjieON7i7rdJC6doMYG+codFtsHMTYXsqLSXA6bxcN99M2mTwQHBzs\n+Jor9tK2i47uhsWLF6GkpBhPPbWyTdazAW7Nqipmco4WMwFNcokidOd+hOFMLvS5udCfyVXKPS8X\nustFV93XMZ3H1he5LTbupqdzLf5aMZNzuDUrUXsyGCD1jIGlZwzQ4OwVABBKS6DPy4U+LxeGuvf6\nvNOtT+f2Iq9bcuF0TgBLm6jNyYFBEAcPgTh4CMwNvyCK0J/7oa7E86DPO11X7KeVc8sbnF8O1E3n\nccpSixjXS1lyccF0Ts7bsGEd9u79DwRBgCzLEAQBY8dOwNy5C9otA5dHVMRMztFiJqBtc101neee\nVpZczv4XgtXa6FjHdB7XCx59eqEyJBK26GhIXaJh69IVcmioqhO6Fn//tJrJGZy0iTTIqem8bu3c\nUFfoHrs/B3Z/Dt+mj+XlBVvnLkqJR3eFFN0VtrpCl6KjIUV2AvT6dnx2dDNY2kTuxGCArWcsbD1j\ngTumNPqSUFKM0MorKMv8Dvr8c9Dln4f+/Pm69z/CkJfb7EPKBgOkqM6wdbFP59GOYpeio2HrHM3l\nFw1haRPdIuSgYKBXN1iir3FaWmUl9PnnlUI/fx76/PPQ5Z9zFLvx0NcQrrFaaguPUAo8uiukLg0K\nvW5al32d+6893TyWNlFH4esLW5942PrEN/91sxm6CwV1ZX4e+vPn6m+fOwdDxkkYjx9r9q5SYKBS\n4F2i69bT64sdiX0A2YNLMC7C0iYihYcHpB49IfXo2fzXbTboLhbWTen1yy/224b/5kHIzmz2rqF6\nPaTQMEjhEZAiIhq/D49s9DG8vdvwSbo/ljYROUevhxTVGVJUZ4hDh139dVmGUFzcYPlFKXPv4iKI\n5wuUrXPP5ELIymjx20h+/pDCwyFFRCrvHcVu/1wEpIhIyMHBHXJLXZY2EbmGIEAOCYEYEgI0uPTf\nO8wPpQ1OrxMqK6C7dBG6ixfr3hdCd+lS3fv6z+v/e+aaa+xA3Q9Qw8KbTO0RjlJvWPJosEmUu2Np\nE1G7kn39YPP1U86AaYnVCt2Vy1eVeePCvwjD96cgZKS3+FBSQGD91B4RAXTtAm9PX0jBIZCCgyEH\nBUMKCoYcEgIpKFjTJc/SJiJtMhohRXZSziNviSxDqChvMq03md7r3gy5px13a/71cOoe0ttbKfSg\nukIPaVDswcH1X6u7LQcHQ/bxbZeLmFjaROTeBAGyfwBs/gHKKw61xGKB7nIRQsQqlJ45D11JMYSS\nYuiuXGl0Wygpga6kGIYzeRCqnXsRBtlobDSty0H1hS4FBSsTfXDj4pcDAq97XZ6lTUQdh8kEKaoz\nEOYHa9dezt3HbFYKvbgYuuIrSrHbbxcX15e9/eOfLsBwKseph5Z1OsiBgZCCQ4AG/wtoCUubiKgl\nHh7KEk1kJ9icvY8oQigtVQq9boq/ZvHX3XYWS5uIyNUMBsihobCFhgJOvkxkmJMP3fFOciQicmMs\nbSIiN8LSJiJyIyxtIiI30uoPIi0WCx588EFYrVbYbDZMmjQJixcvbo9sRETURKulbTKZsGHDBnh5\necFms+H+++/HqFGjkJiY2B75iIioAaeWR7y8vAAoU7coim0aiIiIrs2p0pYkCSkpKUhOTkZycjKn\nbCIilTh1cY1Op8O2bdtQWVmJX/3qV8jLy0NsbAs7dHXvjmDp6i0Vi49nN3t48KD+zX7epcfrhKsy\nqZoHuCqT6nmaZNJEngaZNJPH7tyPmsrD42+N41tzXVdE+vr6YsiQITh48GDLpQ1Ar7t6t6trvkR8\nM8e2xfFNM6mdp2kmLeRpmEkreeyZtJSnxfuolMd+/FX3UznPVffVQJ5GH2skj7MEWW5hl3EAxcXF\nMBqN8PPzQ21tLRYuXIhFixZh9OjRLT5wUYNNz7UgLMyPmZzATM7TYi5mco5WMzmj1Um7qKgIzzzz\nDCRJgiRJmDp1aquFTUREbaPV0u7duze2bt3aHlmIiKgVvCKSiMiNsLSJiNwIS5uIyI2wtImI3AhL\nm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uI\nyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiN\nsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0iYjcCEubiMiNsLSJiNwIS5uIyI2wtImI3AhLm4jIjbC0\niYjciKG1AwoLC5GamorLly9Dr9dj9uzZmDdvXntkIyKiJlotbb1ej2effRbx8fGoqqrC3XffjeTk\nZMTExLRHPiIiaqDV5ZGwsDDEx8cDAHx8fBATE4NLly61eTAiIrrada1p5+fn47vvvkNiYmJb5SEi\noha0ujxiV1VVhaVLl2LlypXw8fFp8dju3QFJuvqY48ermj1+0KDmH8+Vx+t0V2dSMw+AqzKpnadp\nJi3kaZhJK3nszp1r9tOq5eHxt8bxrXGqtEVRxNKlS3HXXXdhwoQJTj2wTnf1EB8W5neNY5t/DFcf\n3zST2nmaZtJCnoaZtJLHnklLeVq6j1p57Mc3vZ/aeZre1kKehh9rJY+zBFmW5dYOSk1NRVBQEJ59\n9lmnH7ioqOKGArWVsDA/ZnICMzlPi7mYyTlazeSMVte0jx8/ju3bt+Obb75BSkoKZs6ciQMHDtx0\nQCIiun6tLo8MGjQIp06dao8sRETUCl4RSUTkRljaRERuhKVNRORGWNpERG6EpU1E5EacviKSiIiu\nnyQBZWVASYmA4uL6t5ISwfG5khIBn37q3OOxtImInGSxoFHRNizgpkWsfAyUlgqQJMFlGVjaRNTh\nyDJQWYmrCvdaU7D981VVzpWvXi8jKEhGaKiMuDgJQUEygoOVt6Ag1L2XHe+DgmQAvk49NkubiG4Z\nNTVAUZGAixcFXLqkw6VLyu2iIuVj5fMCLl8GLBbnLhv39lZKtUcPqVHR1pdw4/INCZHh5wcIrhuu\nG2FpE5GmSZIyEV+6JDhK2F7IDd8uXtShvLzlpvTwkBERIWPgQMDfX2yxfO23vbza6Yk6iaVNRKqo\nqUGjwm1cwvVTcVGRAFFsuYxDQiR07izhtttkhIfLiIiQEB5uvy3X3Zbg769MwMqGUTXt9Exdi6VN\nRC4likBhoYD8fB3OnxdQUQGcPevRYEpWStn5qVhCeLjUqIAblnJYmAyjsZ2enAawtInoutTUABcu\nCDh/Xof8fB3y8+23laK+cEGAzda0kE2OW9eaiusnYuVzbbku7M5Y2kTUSFkZGpVw49sCLl9u/po8\nQZARGSnjZz+T0KWL/U1GfLwnPD2rEBGhnE3RkabitsDSJupAZFlZR25Ywsq0XH+7oqL58dZolNG5\ns4z4eBFdusjo0kVCdLTkuB0VJcNkuvp+YWGeKCqS2viZdRwsbaJbiNUKnDvXtJDrlzIKCgSYzc2X\nso+P3KiEu3SxfywhOlpZtmjppdeofbC0idyMLAM//SQgL0+H3FwdzpzRIS9PeV9QAEhS8xdphIZK\niI9X1pPrC7m+mAMDuYbsDljaRBpVXQ2cOaOUccNyzsvTobr66naNiJCQlARERFgbTczR0TI6d5bg\n7a3CkyCXY2kTqcg+Nefm1heyfWrOz796LcLTU0bPnhJiYxu/xcQoZ1so5x/XqvBMqL2wtInagX1q\nbljK9um5uak5MlLCyJEiYmIal3OXLlxX7uhY2kQuIsvK+csNJ2b7W0HBtafmuDjJUc72277O7R1E\nHRBLm+g6WSzA99/rcOkScOKEqdH03NzU3KmTMjU3XMqIi5PQuTOnZrp+LG2iFtTUADk5OmRm6pGV\npbw/dUoHq9Vezh4AAC+va681c2omV2JpE9WprASys/XIzKwv6dOndY0uyfbwkJGQIKF/fxsGDzYh\nIqIasbGcmqn9sLSpQyopAbKylIJW3utx5kzj1vX2ljF4sA2JiRISEpT3cXGS4zLssDATiopsKqSn\njoylTbe8S5cEx9KGvaTPnWtc0AEBMkaOFJGQICEx0YbERBt69uT0TNrD0qZbhv3sjYblnJmpQ2Fh\n4+YNDZUwbpyIxESbo6S7dpV5NSC5BZY2uSVZBn74QXAUs30N+sqVxgUdFSVh8mRrgwlaQmQkC5rc\nF0ubNM9mA06f1jUq56ws/VWb6HfrJiEpyepYg05IkBAWJquUmqhtsLRJc8xmID1dj6+/1iMtTY/j\nx4Hqah/H1wVBeYXrCRPqp+f+/W0IDFQxNFE7YWmT6mprgRMnlIJOS9Pj2DE9amvrp+h+/YCEBKtj\nDbpfPxvPfaYOi6VN7a6mBjh+vL6kjx/XO/Z4FgQZfftKSE62YfhwG4YPF9G7NzdBIrJjaVObq64G\njh2rL+lQSYIFAAANpklEQVQTJ/SwWOpLun9/CUlJNiQl2TBsmIigIJUDE2kYS5tcrqrq6pK2X/at\n09WXdHKyiKFDuRZNdD1Y2nTTKiuBo0ftJW1AeroOolhf0omJ9klaKemAAJUDE7kxljZdt8pK4MgR\n+9kdBmRk1Je0Xi9jwAAJw4crk/SQITb4+6scmOgW0mppr1y5Evv27UNISAi2b9/eHplIYyoqgMOH\n6yfpjIz6TZT0ehkDB0pIShKRnGzDkCE8s4OoLbVa2nfffTfmzp2L1NTU9shDGlBeDnzzjVLQaWnK\nFYeSpJS0wSDjttskJCeLGD6cJU3U3lot7cGDB6OgoKA9spBKZBk4eVKHXbsMOHgQSE/3dZS00ajs\ndGc/u+P2223w8WnlAYmozXBNu4OyWoFDh/TYtcuAXbsMuHBB2bPDaARuv92G5GSlpAcPtvFVvIk0\npM1KOyzMr60e+oZ19EzV1cAXXwBbtwLbtyt7SgNAYCAwdy6QkgJMmgT4+BigtX/Ptfh7B2gzFzM5\nR4uZnNFmfzOLiira6qFvSFiYX4fMVFICfPGFATt3GrBvnwE1NcqyR2SkhAULREydKiIpyebY2N/H\np2P+Ot0ILeZiJudoNZMznCptWeZOae7kwgUBu3YpRZ2Wpnec6REba8PUqUpRDxwocYN/IjfUammv\nWLEChw8fRmlpKcaMGYMlS5Zg1qxZ7ZGNrkNurg47dypFnZ6ud3z+ttuUop4yRUSvXpKKCYnIFVot\n7bVr17ZHDrpOkqSc8WEv6rw8paj1euVls+xFHRXF/yUR3Uq09dMmapHVCqSl1Z/x8dNPyvqGl5eM\nKVOsmDpVxB13cMMlolsZS1vjqquBvXuVaXr3bgNKS5X16cBAGffeqxT1mDEiT8sj6iBY2hpUUgJ8\n/rlS1Pv315/xERUlYdYspaiHDas/44OIOg6WtkYUFAiOZY+GZ3z06lX/g8SBAyW+IC1RB8fSVtGp\nU8C775qwc6cBJ0/Wn/Hxs5/ZT82zIjaWP0gkonos7XZWWgp89JER775rxKlTAOABg0HGqFH1Z3x0\n6sSiJqLmsbTbgSwDR4/qsGGDCZ9+akBtrQCjUcbMmcCECTWYOFHkq7cQkVNY2m2orEyZqjduNOLU\nKWX5o0cPCXPnmjFnjoi+fX1RVCSqnJKI3AlL28VkGTh2TIeNG0345BPlzA+jUcZdd1kxd64VI0bY\nePk4Ed0wlraLlJUBmzcbsWFD/VTdrZuEuXMtuP9+K8LCuE5NRDePpX0TZBk4cUJZq962TZmqDQYZ\nM2YoU/XIkZyqici1WNo3oLy8fqrOyWk8Vd93nxXh4ZyqiahtsLSdJMtAeroOGzYYsW2bEdXVylQ9\nfboV8+ZZMWoUp2oianss7VZUVChT9caNRmRnK1N11671U3VEBKdqImo/LO1m2F/oduNGI7ZsUaZq\nvV7GtGnKVD16NKdqIlIHS7uBysr6qTorq36q/vnPlTNAOFUTkdpY2gAyMpS16o8/rp+qp05Vpuox\nYzhVE5F2dNjSrqwEtmxRzgDJzFSm6i5dJCxdasEDD1gRGcmpmoi0p8OVdmamDu+8o6xVV1UpU/Xk\nyVbMn69M1Xp9649BRKSWDlHalZXAtm3A3//u7dgCtXNnCYsXK1M1d9UjIndxS5d2eTnwxhsmvP66\nCeXlgE6nw+TJylr12LGcqonI/dySpV1ZCbz1lgl/+5sJpaUCQkIkrF4tICWliq9OTkRu7ZYq7epq\nYN06I/72NxOuXNEhMFDGc8+ZsXChBT16+KGoiIVNRO7tlijt2lpgwwYj/vxnE4qKdPD3l5Gaasai\nRRb4+6udjojIddy6tM1m4N13lbIuLNTBx0fG8uVmPPaYha8EQ0S3JLcsbasVeP99I/70JxMKCnTw\n9paxZIkZv/qVFSEhXAIholuXW5W2KAIffWTA2rUeOHdOB09PGY89ZsGSJRa+yAARdQhuUdo2G7Bl\niwF//KMHzp7VwWSS8cgjFixbZuF+IETUoWi6tCUJ+PRTA1591YTcXD2MRhkPPWTB449beOoeEXVI\nmixtSQJ27lTK+tQpPfR6GT//uVLWXbuyrImo49JUacsy8MUXerz8sgeys/XQ6WTMmWPF8uVm9OjB\nsiYi0kRpyzKwd69S1unpegiCjLvvtuLJJ82IjWVZExHZqVrasgwcPKiU9dGjykYgM2ZY8eSTFvTp\nI6kZjYhIk1Qr7UOH9HjpJRMOHVIiTJlixVNPWdC/P8uaiOha2r20jx7V4aWXPHDwoPKtJ04UkZpq\nxoABLGsiotY49UJaBw4cwOTJkzFp0iS88cYbN/SNTpzQ4b77vDBtmg8OHjRgzBgRu3ZV4b33aljY\nREROanXSliQJL7zwAtavX4/w8HDcc889GD9+PGJiYpz6BllZOrzyigc+/1z5ViNGiEhNtWDYMNvN\nJSci6oBaLe3MzEx069YNnTt3BgBMmzYNe/bsabW0c3J0ePVVE3bsMAIAhg4V8fTTFowYwbImIrpR\nrZb2xYsX0alTJ8fHERERyMrKavE+990HfPihN2RZwKBBNjz9tBmjR9sgCDcfmIioI2u1tGX5+s+T\n3rQJGDBAwtNPmzF+PMuaiMhVWi3tyMhIXLhwwfHxxYsXER4e3uJ9lJ7XA/C+yXiuFRbmp3aEqzCT\nc7SYCdBmLmZyjhYzOaPVs0cSEhJw7tw5FBQUwGKxYMeOHRg/fnx7ZCMioiZanbT1ej1WrVqFhx9+\nGLIs45577nH6zBEiInItQb6RRWsiIlKFUxfXEBGRNrC0iYjcCEubiMiNuHTDqAMHDmDNmjWQZRmz\nZs3CokWLXPnwN2TlypXYt28fQkJCsH37drXjAAAKCwuRmpqKy5cvQ6/XY/bs2Zg3b56qmSwWCx58\n8EFYrVbYbDZMmjQJixcvVjWTnSRJmDVrFiIiIvD666+rHQfjxo2Dr68vdDodDAYDNm/erHYkVFRU\n4LnnnkNubi50Oh3WrFmDAQMGqJrp7NmzeOKJJyAIAmRZxvnz57Fs2TLV/6yvX78emzdvhiAI6NWr\nF1588UWYTCZVM73zzjuOP0et9oHsIjabTZ4wYYKcn58vWywWecaMGXJeXp6rHv6GHT16VM7JyZGn\nT5+udhSHS5cuyTk5ObIsy3JlZaV8xx13aOLXqrq6WpZlWRZFUZ49e7ackZGhciLF22+/La9YsUJ+\n9NFH1Y4iy7Isjxs3Ti4tLVU7RiNPP/20vHnzZlmWZdlqtcoVFRUqJ2rMZrPJycnJ8oULF1TNUVhY\nKI8bN042m82yLMvysmXL5K1bt6qa6fTp0/L06dNls9ksi6IoP/TQQ/KPP/54zeNdtjzScI8So9Ho\n2KNEbYMHD4a/v7/aMRoJCwtDfHw8AMDHxwcxMTG4dOmSyqkALy8vAMrULYqiymkUhYWF2L9/P2bP\nnq12FAdZliFJ2tmZsrKyEseOHcOsWbMAAAaDAb6+viqnaiwtLQ1du3ZttCWGWiRJQk1NDURRRG1t\nbasXC7a1M2fOYODAgTCZTNDr9bj99tuxe/fuax7vstJubo8SLRSR1uXn5+O7775DYmKi2lEgSRJS\nUlKQnJyM5ORkTWRas2YNUlNTIWhoLwRBELBw4ULMmjULH374odpxkJ+fj6CgIDz77LOYOXMmVq1a\nhdraWrVjNbJz505MmzZN7RiIiIjAggULMGbMGIwaNQp+fn5ISkpSNVNcXByOHj2KsrIy1NTU4MCB\nA/jpp5+uebzLSlvm6d7XraqqCkuXLsXKlSvh4+OjdhzodDps27YNBw4cQEZGBvLy8lTNs2/fPoSG\nhiI+Pl5Tf74++OADbNmyBW+++Sbee+89HDt2TNU8oigiJycHDzzwALZu3QpPT88b3ve+LVitVnz5\n5ZeYMmWK2lFQXl6OPXv2YO/evTh48CCqq6tV/1lXTEwMfvGLX2DBggVYtGgR+vTpA4Ph2j9udFlp\n38geJR2ZKIpYunQp7rrrLkyYMEHtOI34+vpiyJAhOHjwoKo5Tpw4gS+//BLjx4/HihUrcPjwYaSm\npqqaCVCWtwAgODgYEydObHXXy7YWGRmJyMhIJCQkAAAmTZqEnJwcVTM1dODAAfTr1w/BwcFqR0Fa\nWhqio6MRGBgIvV6PiRMnIj09Xe1YmDVrFrZs2YKNGzciICAA3bp1u+axLittLe9RoqUpzW7lypWI\njY3F/Pnz1Y4CACguLkZFRQUAoLa2FocOHULPnj1VzbR8+XLs27cPe/bswWuvvYahQ4filVdeUTVT\nTU0NqqqqAADV1dX46quvEBcXp2qm0NBQdOrUCWfPngUAfPPNN5raamLHjh2YPn262jEAAFFRUcjI\nyIDZbIYsy5r5tSouLgYAXLhwAbt3727x18tlp/xpdY8S+4RWWlqKMWPGYMmSJY4f2Kjl+PHj2L59\nO3r16oWUlBQIgoAnnngCo0aNUi1TUVERnnnmGUiSBEmSMHXqVIwePVq1PFp1+fJlLF68GIIgwGaz\n4c4778SIESPUjoXnn38eTz75JERRRHR0NF588UW1IwFQBoC0tDT87ne/UzsKACAxMRGTJk1CSkoK\nDAYD+vbti3vvvVftWFiyZAnKyspgMBjwm9/8Bn5+196BkHuPEBG5EV4RSUTkRljaRERuhKVNRORG\nWNpERG6EpU1E5EZY2kREboSlTUTkRljaRERu5P8D+7Wym3BFpegAAAAASUVORK5CYII=\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0x7f5be4b8ec50\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "model = Model()\n", + "\n", + "# Collect the history of W-values and b-values to plot later\n", + "Ws, bs = [], []\n", + "epochs = range(10)\n", + "for epoch in epochs:\n", + " Ws.append(model.W.numpy())\n", + " bs.append(model.b.numpy())\n", + " current_loss = loss(model(inputs), outputs)\n", + "\n", + " train(model, inputs, outputs, learning_rate=0.1)\n", + " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n", + " (epoch, Ws[-1], bs[-1], current_loss))\n", + "\n", + "# Let's plot it all\n", + "plt.plot(epochs, Ws, 'r',\n", + " epochs, bs, 'b')\n", + "plt.plot([TRUE_W] * len(epochs), 'r--',\n", + " [TRUE_b] * len(epochs), 'b--')\n", + "plt.legend(['W', 'b', 'true W', 'true_b'])\n", + "plt.show()\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vPnIVuaSJwWz" + }, + "source": [ + "## Next Steps\n", + "\n", + "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n", + "\n", + "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n", + "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n", + "\n", + "The [next tutorial](TODO) will cover these higher level APIs." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "Training Models", + "provenance": [], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..5749f22ac58e0a012ed7e3fec4dfe2913d3f8273 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/4_high_level.ipynb @@ -0,0 +1,551 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "pwX7Fii1rwsJ" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "tf.enable_eager_execution()\n", + "tfe = tf.contrib.eager\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UEu3q4jmpKVT" + }, + "source": [ + "# High level API\n", + "\n", + "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "zSFfVVjkrrsI" + }, + "source": [ + "## Layers: common sets of useful operations\n", + "\n", + "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n", + "\n", + "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n", + "\n", + "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "8PyXlPl-4TzQ" + }, + "outputs": [], + "source": [ + "# In the tf.keras.layers package, layers are objects. To construct a layer,\n", + "# simply construct the object. Most layers take as a first argument the number\n", + "# of output dimensions / channels.\n", + "layer = tf.keras.layers.Dense(100)\n", + "# The number of input dimensions is often unnecessary, as it can be inferred\n", + "# the first time the layer is used, but it can be provided if you want to \n", + "# specify it manually, which is useful in some complex models.\n", + "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Fn69xxPO5Psr" + }, + "source": [ + "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n", + "Conv2D, LSTM, BatchNormalization, Dropout, and many others." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 204 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 244, + "status": "ok", + "timestamp": 1527783641557, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "E3XKNknP5Mhb", + "outputId": "c5d52434-d980-4488-efa7-5660819d0207" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=30, shape=(10, 10), dtype=float32, numpy=\n", + "array([[ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)\u003e" + ] + }, + "execution_count": 3, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# To use a layer, simply call it.\n", + "layer(tf.zeros([10, 5]))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 320, + "status": "ok", + "timestamp": 1527783642457, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "Wt_Nsv-L5t2s", + "outputId": "f0d96dce-0128-4080-bfe2-0ee6fbc0ad90" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e]" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Layers have many useful methods. For example, you can inspect all variables\n", + "# in a layer by calling layer.variables. In this case a fully-connected layer\n", + "# will have variables for weights and biases.\n", + "layer.variables" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 221 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 226, + "status": "ok", + "timestamp": 1527783643252, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "6ilvKjz8_4MQ", + "outputId": "f647fced-c2d7-41a3-c237-242036784665" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(\u003ctf.Variable 'dense_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + " array([[ 0.43788117, -0.62099844, -0.30525017, -0.59352523, 0.1783089 ,\n", + " 0.47078604, -0.23620895, -0.30482283, 0.01366901, -0.1288507 ],\n", + " [ 0.18407935, -0.56550485, 0.54180616, -0.42254075, 0.3702994 ,\n", + " 0.36705834, -0.29678228, 0.36660975, 0.36717761, 0.46269661],\n", + " [ 0.1709305 , -0.11529458, 0.32710236, 0.46300393, -0.62802851,\n", + " 0.51641601, 0.39624029, 0.26918125, -0.25196898, 0.21353298],\n", + " [ 0.35752094, 0.44161648, 0.61500639, -0.12653333, 0.41629118,\n", + " 0.36193585, 0.066082 , -0.59253877, 0.47318751, 0.17115968],\n", + " [-0.22554061, -0.17727301, 0.5525015 , 0.3678053 , -0.00454676,\n", + " 0.24066836, -0.53640735, 0.13792562, -0.10727292, 0.59708995]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32, numpy=array([ 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)\u003e)" + ] + }, + "execution_count": 5, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# The variables are also accessible through nice accessors\n", + "layer.kernel, layer.bias" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "O0kDbE54-5VS" + }, + "source": [ + "## Implementing custom layers\n", + "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n", + " * `__init__` , where you can do all input-independent initialization\n", + " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n", + " * `call`, where you do the forward computation\n", + "\n", + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 391 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 251, + "status": "ok", + "timestamp": 1527783661512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "5Byl3n1k5kIy", + "outputId": "6e7f9285-649a-4132-82ce-73ea92f15862" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n", + " [ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(10, 10), dtype=float32)\n", + "[\u003ctf.Variable 'my_dense_layer_1/kernel:0' shape=(5, 10) dtype=float32, numpy=\n", + "array([[-0.4011991 , 0.22458655, -0.33237562, -0.25117266, 0.33528614,\n", + " -0.01392961, 0.58580834, -0.16346583, 0.28465688, -0.47191954],\n", + " [-0.52922136, 0.22416979, -0.58209574, -0.60914612, 0.05226624,\n", + " -0.18325993, 0.5591442 , -0.24718609, 0.37148207, 0.40475875],\n", + " [ 0.16912812, -0.47618777, -0.38989353, 0.30105609, -0.08085585,\n", + " 0.44758242, 0.545829 , 0.51421839, 0.11063248, 0.20159996],\n", + " [ 0.34073615, -0.59835428, 0.06498981, -0.44489855, -0.34302285,\n", + " 0.20969599, 0.35527444, -0.03173476, -0.22227573, 0.09303057],\n", + " [ 0.41764337, -0.06435019, -0.52509922, -0.39957345, 0.56811184,\n", + " 0.23481232, -0.61666459, 0.31144124, -0.11532354, -0.42421889]], dtype=float32)\u003e]\n" + ] + } + ], + "source": [ + "class MyDenseLayer(tf.keras.layers.Layer):\n", + " def __init__(self, num_outputs):\n", + " super(MyDenseLayer, self).__init__()\n", + " self.num_outputs = num_outputs\n", + " \n", + " def build(self, input_shape):\n", + " self.kernel = self.add_variable(\"kernel\", \n", + " shape=[input_shape[-1].value, \n", + " self.num_outputs])\n", + " \n", + " def call(self, input):\n", + " return tf.matmul(input, self.kernel)\n", + " \n", + "layer = MyDenseLayer(10)\n", + "print(layer(tf.zeros([10, 5])))\n", + "print(layer.variables)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "tk8E2vY0-z4Z" + }, + "source": [ + "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n", + "\n", + "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Qhg4KlbKrs3G" + }, + "source": [ + "## Models: composing layers\n", + "\n", + "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n", + "\n", + "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 190 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 420, + "status": "ok", + "timestamp": 1527783698512, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "N30DTXiRASlb", + "outputId": "a8b23a8e-5cf9-4bbf-f93b-6c763d74e2b3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]\n", + "\n", + " [[ 0. 0. 0.]\n", + " [ 0. 0. 0.]\n", + " [ 0. 0. 0.]]]], shape=(1, 2, 3, 3), dtype=float32)\n", + "['resnet_identity_block_1/conv2d_3/kernel:0', 'resnet_identity_block_1/conv2d_3/bias:0', 'resnet_identity_block_1/batch_normalization_3/gamma:0', 'resnet_identity_block_1/batch_normalization_3/beta:0', 'resnet_identity_block_1/conv2d_4/kernel:0', 'resnet_identity_block_1/conv2d_4/bias:0', 'resnet_identity_block_1/batch_normalization_4/gamma:0', 'resnet_identity_block_1/batch_normalization_4/beta:0', 'resnet_identity_block_1/conv2d_5/kernel:0', 'resnet_identity_block_1/conv2d_5/bias:0', 'resnet_identity_block_1/batch_normalization_5/gamma:0', 'resnet_identity_block_1/batch_normalization_5/beta:0', 'resnet_identity_block_1/batch_normalization_3/moving_mean:0', 'resnet_identity_block_1/batch_normalization_3/moving_variance:0', 'resnet_identity_block_1/batch_normalization_4/moving_mean:0', 'resnet_identity_block_1/batch_normalization_4/moving_variance:0', 'resnet_identity_block_1/batch_normalization_5/moving_mean:0', 'resnet_identity_block_1/batch_normalization_5/moving_variance:0']\n" + ] + } + ], + "source": [ + "class ResnetIdentityBlock(tf.keras.Model):\n", + " def __init__(self, kernel_size, filters):\n", + " super(ResnetIdentityBlock, self).__init__(name='')\n", + " filters1, filters2, filters3 = filters\n", + "\n", + " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n", + " self.bn2a = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n", + " self.bn2b = tf.keras.layers.BatchNormalization()\n", + "\n", + " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n", + " self.bn2c = tf.keras.layers.BatchNormalization()\n", + "\n", + " def call(self, input_tensor, training=False):\n", + " x = self.conv2a(input_tensor)\n", + " x = self.bn2a(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2b(x)\n", + " x = self.bn2b(x, training=training)\n", + " x = tf.nn.relu(x)\n", + "\n", + " x = self.conv2c(x)\n", + " x = self.bn2c(x, training=training)\n", + "\n", + " x += input_tensor\n", + " return tf.nn.relu(x)\n", + "\n", + " \n", + "block = ResnetIdentityBlock(1, [1, 2, 3])\n", + "print(block(tf.zeros([1, 2, 3, 3])))\n", + "print([x.name for x in block.variables])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "wYfucVw65PMj" + }, + "source": [ + "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "base_uri": "https://localhost:8080/", + "height": 153 + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 361, + "status": "ok", + "timestamp": 1526674830777, + "user": { + "displayName": "Alexandre Passos", + "photoUrl": "//lh4.googleusercontent.com/-kmTTWXEgAPw/AAAAAAAAAAI/AAAAAAAAAC0/q_DoOzKGwds/s50-c-k-no/photo.jpg", + "userId": "108023195365833072773" + }, + "user_tz": 420 + }, + "id": "L9frk7Ur4uvJ", + "outputId": "882e9076-b6d9-4380-bb1e-7c6b57d54c39" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=1423, shape=(1, 2, 3, 3), dtype=float32, numpy=\n", + "array([[[[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]],\n", + "\n", + " [[0., 0., 0.],\n", + " [0., 0., 0.],\n", + " [0., 0., 0.]]]], dtype=float32)\u003e" + ] + }, + "execution_count": 26, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(2, 1, \n", + " padding='same'),\n", + " tf.keras.layers.BatchNormalization(),\n", + " tf.keras.layers.Conv2D(3, (1, 1)),\n", + " tf.keras.layers.BatchNormalization()])\n", + "my_seq(tf.zeros([1, 2, 3, 3]))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "c5YwYcnuK-wc" + }, + "source": [ + "# Next steps\n", + "\n", + "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured." + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "default_view": {}, + "name": "4 - High level API - TensorFlow Eager.ipynb", + "provenance": [], + "version": "0.3.2", + "views": {} + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 2d51cfdeee3f0b45514af0895366417158b01614..b14ef1df8ff4c660b9b6f2abfd5df6572d10b1e8 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -49,15 +49,17 @@ def random_batch(batch_size, data_format): return images, one_hot -def train_one_step(model, images, labels, optimizer): - +def compute_gradients(model, images, labels): with tf.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - grads = tape.gradient(loss, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) + return tape.gradient(loss, model.variables) + + +def apply_gradients(model, optimizer, gradients): + optimizer.apply_gradients(zip(gradients, model.variables)) class ResNet50Test(tf.test.TestCase): @@ -114,7 +116,8 @@ class ResNet50Test(tf.test.TestCase): with tf.device(device), tfe.execution_mode(execution_mode): optimizer = tf.train.GradientDescentOptimizer(0.1) images, labels = random_batch(2, data_format) - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) self.assertEqual(320, len(model.variables)) tfe.async_wait() events = summary_test_util.events_from_logdir(logdir) @@ -138,14 +141,16 @@ class ResNet50Test(tf.test.TestCase): # garbage to be collected. The hope is that this is a build-only effect, # and a subsequent training loop will create nothing which needs to be # collected. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() previous_gc_debug_flags = gc.get_debug() gc.set_debug(gc.DEBUG_SAVEALL) for _ in range(2): # Run twice to ensure that garbage that is created on the first # iteration is no longer accessible. - train_one_step(model, images, labels, optimizer) + apply_gradients(model, optimizer, + compute_gradients(model, images, labels)) gc.collect() # There should be no garbage requiring collection. self.assertEqual(0, len(gc.garbage)) @@ -180,9 +185,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): return (16, 32, 64) if tf.DeviceSpec.from_string(device.name).device_type == 'TPU': - # TODO(iga): Training fails with batch size of 16, probably because of - # no layout optimizations with op-by-op mode. Investigate more. - return (8,) + return (32,) return (16, 32) def _report(self, label, start, num_iters, device, batch_size, data_format): @@ -248,18 +251,21 @@ class ResNet50Benchmarks(tf.test.Benchmark): device, data_format = device_and_format for batch_size in self._train_batch_sizes(): (images, labels) = random_batch(batch_size, data_format) - num_burn = 3 - num_iters = 10 model = resnet50.ResNet50(data_format) + optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = apply_gradients if defun: model.call = tfe.defun(model.call, compiled=compiled) - optimizer = tf.train.GradientDescentOptimizer(0.1) + apply_grads = tfe.defun(apply_gradients, compiled=compiled) + num_burn = 3 + num_iters = 10 with tf.device(device): iterator = make_iterator((images, labels)) for _ in xrange(num_burn): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() @@ -268,7 +274,8 @@ class ResNet50Benchmarks(tf.test.Benchmark): start = time.time() for _ in xrange(num_iters): (images, labels) = iterator.next() - train_one_step(model, images, labels, optimizer) + apply_grads(model, optimizer, + compute_gradients(model, images, labels)) if execution_mode: tfe.async_wait() self._force_device_sync() diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 492adbe1d80941f9df96d6636e4933d11239408e..5ee2176154ec7011dcb3d7b384a86213e778014f 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -152,7 +152,7 @@ class RNNColorbot(tf.keras.Model): self.label_dimension = label_dimension self.keep_prob = keep_prob - self.cells = self._add_cells( + self.cells = tf.contrib.checkpoint.List( [tf.nn.rnn_cell.BasicLSTMCell(size) for size in rnn_cell_sizes]) self.relu = layers.Dense( label_dimension, activation=tf.nn.relu, name="relu") @@ -204,14 +204,6 @@ class RNNColorbot(tf.keras.Model): hidden_states = tf.gather_nd(chars, indices) return self.relu(hidden_states) - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of layers.Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - def loss(labels, predictions): """Computes mean squared loss.""" diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 74701b2f4f7448c5f6c2c1bd7c67a8ee112f4115..c2340a293a80924f2dfa90e2fb23134b0f1feb6b 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -50,7 +50,7 @@ class RNN(tf.keras.Model): def __init__(self, hidden_dim, num_layers, keep_ratio): super(RNN, self).__init__() self.keep_ratio = keep_ratio - self.cells = self._add_cells([ + self.cells = tf.contrib.checkpoint.List([ tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim) for _ in range(num_layers) ]) @@ -74,14 +74,6 @@ class RNN(tf.keras.Model): # tuple (output, output_states). return [input_seq] - def _add_cells(self, cells): - # "Magic" required for keras.Model classes to track all the variables in - # a list of Layer objects. - # TODO(ashankar): Figure out API so user code doesn't have to do this. - for i, c in enumerate(cells): - setattr(self, "cell-%d" % i, c) - return cells - class Embedding(layers.Layer): """An Embedding layer.""" diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 1ae6415d5ecb03ef97cdf734c808e3f728dafcb0..c947ed9dcc415670a820f8a5cd9eaaf07334cfc3 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -367,6 +368,9 @@ class Accuracy(Mean): Returns: The arguments, for easy chaining. """ + check_ops.assert_equal( + array_ops.shape(labels), array_ops.shape(predictions), + message="Shapes of labels and predictions are unequal") matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 98a98a8d358d8a7d7a06505ed1a7d4c0ff1e18f4..02ee05487515b81bfae70d02c1dfdb6d816b77c7 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops @@ -117,6 +118,11 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testAccuracyDifferentShapes(self): + m = metrics.Accuracy() + with self.assertRaises(errors.InvalidArgumentError): + m([[0], [0]], [0, 1]) + def testWeightedAccuracy(self): m = metrics.Accuracy() # 1 correct, total weight of 2 diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 4032e755f6e7dea9dcb42587f14e8386e5db2338..90a3711475719a7f991473c6c9067da1e76ab9f2 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -60,15 +60,9 @@ class SaverTest(test.TestCase): def testSameNameNoClobbering(self): with ops.device(self._dev()): - # Note that this test purposefully uses Graphs rather than - # IsolateTest. Users are more likely to accidentally create the same - # variable name this way. - first_graph = ops.Graph() - with first_graph.as_default(): - v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') - saver = _saver.Saver([v1_first_graph, v1_second_graph]) + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + v2 = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1, v2]) ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) @@ -126,12 +120,11 @@ class SaverTest(test.TestCase): saver = _saver.Saver([v1]) saver.save(ckpt_prefix) - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - with _saver.restore_variables_on_create(ckpt_prefix): - # Value is from checkpoint, but not from argument. - ret, _ = model(2.0) - self.assertEqual(ret.numpy(), 1.0) + saver = _saver.Saver([v1]) + with _saver.restore_variables_on_create(ckpt_prefix): + # Value is from checkpoint, but not from argument. + ret, _ = model(2.0) + self.assertEqual(ret.numpy(), 1.0) def testRestoreNotFound(self): with ops.device(self._dev()): @@ -184,17 +177,17 @@ class SaverTest(test.TestCase): 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) # reset the graph and reload on create, so that 1 + 2 = 3 - with ops.Graph().as_default(): - with _saver.restore_variables_on_create(ckpt_prefix): - @graph_callable.graph_callable( - [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) - def model2(x): - v = variable_scope.get_variable( - 'v', initializer=init_ops.zeros_initializer(), shape=()) - return v + x - - self.assertEqual( - 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + ops.reset_default_graph() + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) class GetOptimizerTests(test.TestCase): diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 5826700c73e255198e9a6974ca240ba55e438a26..fee9db46fa4f79d7dd613436726e8ddad51faf1c 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -115,6 +115,7 @@ from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes +from tensorflow.python.framework.test_util import run_all_in_graph_and_eager_modes as run_all_tests_in_graph_and_eager_modes from tensorflow.python.ops.custom_gradient import custom_gradient from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index d5d2abf8c4c82374842ed2e10a849765a6dddd3b..1937ffb583bc727df76470d072b35fb3c9acaa88 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -312,6 +312,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", @@ -340,6 +341,7 @@ py_test( size = "medium", srcs = ["python/estimator/hooks_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":hooks", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 8b97f86db19a1bc2d9f17c9935e6678844daf177..9594e5132fd20dadea118fd1dd6768feb7fd7fff 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -529,6 +529,7 @@ def multi_label_head(n_classes, applications, the shape is `[batch_size, n_classes]`. Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` * An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. @@ -845,6 +846,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index d6c158608b5c564f24bc90583084306aa7084742..b2b57fa06ba818d4455871fe57dde5ce287b39a2 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -989,6 +990,34 @@ class MultiLabelHead(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + head = head_lib.multi_label_head(n_classes=2) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=np.array([[1, 0], [1, 1]], dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) diff --git a/tensorflow/contrib/estimator/python/estimator/hooks.py b/tensorflow/contrib/estimator/python/estimator/hooks.py index 4808b9ee30e10047aaf3d33f74457b2717c87a13..ddd6aa442f82bad2d4714dbcdc85b20b34773068 100644 --- a/tensorflow/contrib/estimator/python/estimator/hooks.py +++ b/tensorflow/contrib/estimator/python/estimator/hooks.py @@ -72,7 +72,7 @@ class InMemoryEvaluatorHook(training.SessionRunHook): estimator: A `tf.estimator.Estimator` instance to call evaluate. input_fn: Equivalent to the `input_fn` arg to `estimator.evaluate`. A function that constructs the input data for evaluation. - See @{$get_started/premade_estimators#create_input_functions} for more + See @{$premade_estimators#create_input_functions} for more information. The function should construct and return one of the following: diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops.py b/tensorflow/contrib/factorization/python/ops/factorization_ops.py index 5cef4068ed119d5dbccd585c5b4e5e28840d2cc7..8f73274c2a0ebbdc41ce6a647a8a5650694c9a23 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops.py @@ -197,7 +197,8 @@ class WALSModel(object): row_weights=1, col_weights=1, use_factors_weights_cache=True, - use_gramian_cache=True): + use_gramian_cache=True, + use_scoped_vars=False): """Creates model for WALS matrix factorization. Args: @@ -239,6 +240,8 @@ class WALSModel(object): weights cache to take effect. use_gramian_cache: When True, the Gramians will be cached on the workers before the updates start. Defaults to True. + use_scoped_vars: When True, the factor and weight vars will also be nested + in a tf.name_scope. """ self._input_rows = input_rows self._input_cols = input_cols @@ -251,25 +254,46 @@ class WALSModel(object): regularization * linalg_ops.eye(self._n_components) if regularization is not None else None) assert (row_weights is None) == (col_weights is None) - self._row_weights = WALSModel._create_weights( - row_weights, self._input_rows, self._num_row_shards, "row_weights") - self._col_weights = WALSModel._create_weights( - col_weights, self._input_cols, self._num_col_shards, "col_weights") self._use_factors_weights_cache = use_factors_weights_cache self._use_gramian_cache = use_gramian_cache - self._row_factors = self._create_factors( - self._input_rows, self._n_components, self._num_row_shards, row_init, - "row_factors") - self._col_factors = self._create_factors( - self._input_cols, self._n_components, self._num_col_shards, col_init, - "col_factors") + + if use_scoped_vars: + with ops.name_scope("row_weights"): + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + with ops.name_scope("col_weights"): + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + with ops.name_scope("row_factors"): + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, + row_init, "row_factors") + with ops.name_scope("col_factors"): + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, + col_init, "col_factors") + else: + self._row_weights = WALSModel._create_weights( + row_weights, self._input_rows, self._num_row_shards, "row_weights") + self._col_weights = WALSModel._create_weights( + col_weights, self._input_cols, self._num_col_shards, "col_weights") + self._row_factors = self._create_factors( + self._input_rows, self._n_components, self._num_row_shards, row_init, + "row_factors") + self._col_factors = self._create_factors( + self._input_cols, self._n_components, self._num_col_shards, col_init, + "col_factors") + self._row_gramian = self._create_gramian(self._n_components, "row_gramian") self._col_gramian = self._create_gramian(self._n_components, "col_gramian") - self._row_update_prep_gramian = self._prepare_gramian( - self._col_factors, self._col_gramian) - self._col_update_prep_gramian = self._prepare_gramian( - self._row_factors, self._row_gramian) - self._create_transient_vars() + with ops.name_scope("row_prepare_gramian"): + self._row_update_prep_gramian = self._prepare_gramian( + self._col_factors, self._col_gramian) + with ops.name_scope("col_prepare_gramian"): + self._col_update_prep_gramian = self._prepare_gramian( + self._row_factors, self._row_gramian) + with ops.name_scope("transient_vars"): + self._create_transient_vars() @property def row_factors(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py index 555beddeaab419bcb23d06f960d370b706d744c8..b588f75efe9d0bbf8213a89978a627c0a0ccf554 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py @@ -346,7 +346,8 @@ def sequence_numeric_column( key, shape=(1,), default_value=0., - dtype=dtypes.float32): + dtype=dtypes.float32, + normalizer_fn=None): """Returns a feature column that represents sequences of numeric data. Example: @@ -370,6 +371,12 @@ def sequence_numeric_column( default_value: A single value compatible with `dtype` that is used for padding the sparse data into a dense `Tensor`. dtype: The type of values. + normalizer_fn: If not `None`, a function that can be used to normalize the + value of the tensor after `default_value` is applied for parsing. + Normalizer function takes the input `Tensor` as its argument, and returns + the output `Tensor`. (e.g. lambda x: (x - 3.0) / 4.2). Please note that + even though the most common use case of this function is normalization, it + can be used for any kind of Tensorflow transformations. Returns: A `_SequenceNumericColumn`. @@ -383,12 +390,16 @@ def sequence_numeric_column( if not (dtype.is_integer or dtype.is_floating): raise ValueError('dtype must be convertible to float. ' 'dtype: {}, key: {}'.format(dtype, key)) + if normalizer_fn is not None and not callable(normalizer_fn): + raise TypeError( + 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) return _SequenceNumericColumn( key, shape=shape, default_value=default_value, - dtype=dtype) + dtype=dtype, + normalizer_fn=normalizer_fn) def _assert_all_equal_and_return(tensors, name=None): @@ -407,7 +418,7 @@ class _SequenceNumericColumn( fc._SequenceDenseColumn, collections.namedtuple( '_SequenceNumericColumn', - ['key', 'shape', 'default_value', 'dtype'])): + ['key', 'shape', 'default_value', 'dtype', 'normalizer_fn'])): """Represents sequences of numeric data.""" @property @@ -419,7 +430,10 @@ class _SequenceNumericColumn( return {self.key: parsing_ops.VarLenFeature(self.dtype)} def _transform_feature(self, inputs): - return inputs.get(self.key) + input_tensor = inputs.get(self.key) + if self.normalizer_fn is not None: + input_tensor = self.normalizer_fn(input_tensor) + return input_tensor @property def _variable_shape(self): diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index 88f5d535162939e063eb1e7f43d495137c5adef4..89b5f4c4137f6c42417f539a578fd8b11f8b235d 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test from tensorflow.python.training import monitored_session @@ -670,6 +671,7 @@ class SequenceNumericColumnTest(test.TestCase): self.assertEqual((1,), a.shape) self.assertEqual(0., a.default_value) self.assertEqual(dtypes.float32, a.dtype) + self.assertIsNone(a.normalizer_fn) def test_shape_saved_as_tuple(self): a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) @@ -688,6 +690,10 @@ class SequenceNumericColumnTest(test.TestCase): ValueError, 'dtype must be convertible to float'): sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + def test_normalizer_fn_must_be_callable(self): + with self.assertRaisesRegexp(TypeError, 'must be a callable'): + sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable') + def test_get_sequence_dense_tensor(self): sparse_input = sparse_tensor.SparseTensorValue( # example 0, values [[0.], [1]] @@ -708,6 +714,41 @@ class SequenceNumericColumnTest(test.TestCase): self.assertAllEqual( expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_normalizer_fn(self): + + def _increment_two(input_sparse_tensor): + return sparse_ops.sparse_add( + input_sparse_tensor, + sparse_tensor.SparseTensor(((0, 0), (1, 1)), (2.0, 2.0), (2, 2)) + ) + + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + + # Before _increment_two: + # [[0.], [1.]], + # [[10.], [0.]], + # After _increment_two: + # [[2.], [1.]], + # [[10.], [2.]], + expected_dense_tensor = [ + [[2.], [1.]], + [[10.], [2.]], + ] + numeric_column = sfc.sequence_numeric_column( + 'aaa', normalizer_fn=_increment_two) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + def test_get_sequence_dense_tensor_with_shape(self): """Tests get_sequence_dense_tensor with shape !=(1,).""" sparse_input = sparse_tensor.SparseTensorValue( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index daba965a98893b992abdc598ec713f13020d6e91..484ffee3e7afe55c63cab2a463454353b2663e18 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -28,7 +28,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio -from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 020b5c99c61019254bef0b1dff6bc5901c92758a..b1b5126d9e9e5196a1733b80e0778e53cef7f774 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py -from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 10d1ecc738de6777784200ba934a521dff592e28..dc49383c5c300e82839c478e097074b3e8776b3b 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -119,14 +119,13 @@ from tensorflow.python.framework.smart_cond import smart_cond from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec -from tensorflow.python.ops.array_ops import broadcast_to from tensorflow.python.ops.init_ops import convolutional_delta_orthogonal from tensorflow.python.ops.init_ops import convolutional_orthogonal_1d from tensorflow.python.ops.init_ops import convolutional_orthogonal_2d from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['nest', 'broadcast_to'] +_allowed_symbols = ['nest'] _nest_allowed_symbols = [ 'assert_same_structure', 'is_sequence', diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 0eb6889db1fae1c74aeb4392441b308392b091a5..0f0813c07f8bd330b089780064e02f8dfe7d49f6 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -75,6 +75,7 @@ tf_kernel_library( "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", "//third_party/eigen3", + "@local_config_cuda//cuda:cudnn_header", ], alwayslink = 1, ) @@ -94,6 +95,7 @@ tf_custom_op_library( "//tensorflow/core/kernels:conv_ops_gpu_hdrs", "//tensorflow/core/kernels:gpu_util_hdrs", "//tensorflow/core/kernels:ops_util_hdrs", + "@local_config_cuda//cuda:cudnn_header", ], ) diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 3d0ed899322c26bf4ae428930899d7a5885e9f21..a955e21b72e765f751318c7927f9644481fe7933 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -21,6 +21,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.fused_conv.python.ops import fused_conv2d_bias_activation_op +from tensorflow.core.protobuf import config_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -33,6 +35,13 @@ from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging +def NoMemoryOptimizationConfig(): + config = config_pb2.ConfigProto() + config.graph_options.rewrite_options.memory_optimization = ( + rewriter_config_pb2.RewriterConfig.OFF) + return config + + def GetShrunkInceptionShapes(shrink=10): """Iterator for smaller versions of convolution shapes in 2015 Inception. @@ -193,7 +202,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): # This is to guarantee that there is always negative values after # bias add so that we can test whether relu works correctly. x3 = bias - with self.test_session(use_gpu=True): + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session(use_gpu=True, config=NoMemoryOptimizationConfig()): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) fused_t2 = t2 @@ -241,7 +251,9 @@ class FusedConv2DBiasActivationTest(test.TestCase): x3 = np.random.rand(*[filter_in_sizes[-1]]).astype(np.float32) def _SetupVal(data_format, use_gpu): - with self.test_session(use_gpu=use_gpu): + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session( + use_gpu=use_gpu, config=NoMemoryOptimizationConfig()): t1 = constant_op.constant(x1, shape=tensor_in_sizes) t2 = constant_op.constant(x2, shape=filter_in_sizes) t3 = constant_op.constant(x3, shape=[filter_in_sizes[-1]]) @@ -289,8 +301,8 @@ class FusedConv2DBiasActivationTest(test.TestCase): conv = tensors[i] value = values[i] ref_value = ref_values[i] - print("expected = ", ref_value) - print("actual = ", value) + tf_logging.info("expected = ", ref_value) + tf_logging.info("actual = ", value) tol = 1e-5 if value.dtype == np.float16: tol = 1e-3 @@ -831,7 +843,8 @@ class FusedConvInt8Tests(test.TestCase): vertical_stride, padding_type) output_width = CalculateConvolvedOutputDim(input_width, filter_width, horizontal_stride, padding_type) - print("output_height=", output_height, ", output_width=", output_width) + tf_logging.info("output_height=", output_height, ", output_width=", + output_width) side_input, _, _ = gen_array_ops.quantize_v2( random_ops.random_uniform( @@ -864,10 +877,12 @@ class FusedConvInt8Tests(test.TestCase): conv_input_scale, conv_input, kernel, padding_type, strides, side_input_scale, side_input, biases) - with self.test_session(use_gpu=True) as sess: + # TODO(b/79323979): re-enable memory optimization after this bug is fixed. + with self.test_session( + use_gpu=True, config=NoMemoryOptimizationConfig()) as sess: actual_y, expected_y = sess.run([actual, expected]) - print("actual_y = ", actual_y) - print("expected_y = ", expected_y) + tf_logging.info("actual_y = ", actual_y) + tf_logging.info("expected_y = ", expected_y) self.assertTrue(np.array_equal(actual_y, expected_y)) def testFusedConvInt8(self): diff --git a/tensorflow/contrib/gdr/gdr_server_lib.cc b/tensorflow/contrib/gdr/gdr_server_lib.cc index 1f9dd0decb84cf9b7b703f18c061d3c0c7a1cb25..9025c992a4467f521d6d8d514e6a5e92f5492947 100644 --- a/tensorflow/contrib/gdr/gdr_server_lib.cc +++ b/tensorflow/contrib/gdr/gdr_server_lib.cc @@ -57,7 +57,7 @@ Status GdrServer::Init() { new GdrWorker(env, remote_memory_manager_.get())); }; TF_RETURN_IF_ERROR( - GrpcServer::Init(nullptr, rendezvous_mgr_func, worker_func)); + GrpcServer::Init(nullptr, rendezvous_mgr_func, nullptr, worker_func)); return remote_memory_manager_->Init(); } diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 592d37b432ee605d74162e0b8ec6ccdf426c45d1..026a3d1200033400472c4fd763a244c04b284a9b 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -189,9 +189,6 @@ def copy_op_handler(info, op, new_inputs, copy_shape=True, nodedef_fn=None): if op._original_op: op_._original_op = op._original_op - # Add op to the graph - info.graph_._add_op(op_) - return op_, op_.outputs @@ -492,7 +489,7 @@ class Transformer(object): t_ = info.transformed_ts[t] consumer_op_ = info.transformed_ops[consumer_op] t_index_ = list(consumer_op_.inputs).index(tmp_t_) - consumer_op_._update_input(t_index_, t_, update_dtype=False) # pylint: disable=protected-access + consumer_op_._update_input(t_index_, t_) # pylint: disable=protected-access def _connect_control_inputs(self, info): """Connect the previously copied ops.""" diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index b4a99867ed46897f60be3f230838c3f576d5455e..61f78febfc07bb4e677259366a81c16b2b585244 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops @@ -279,13 +278,27 @@ def _assert_increasing(t): return ops.control_dependencies([assert_increasing]) -def _check_input_types(t, y0): +def _check_input_types(y0, t, dt=None): if not (y0.dtype.is_floating or y0.dtype.is_complex): raise TypeError('`y0` must have a floating point or complex floating ' 'point dtype') if not t.dtype.is_floating: raise TypeError('`t` must have a floating point dtype') + if dt is not None and not dt.dtype.is_floating: + raise TypeError('`dt` must have a floating point dtype') + + +def _check_input_sizes(t, dt): + if len(t.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if len(dt.get_shape().as_list()) > 1: + raise ValueError('t must be a 1D tensor') + + if t.get_shape()[0] != dt.get_shape()[0] + 1: + raise ValueError('t and dt have incompatible lengths, must be N and N-1') + def _dopri5(func, y0, @@ -510,7 +523,7 @@ def odeint(func, # avoiding the need to pack/unpack in user functions. y0 = ops.convert_to_tensor(y0, name='y0') t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') - _check_input_types(t, y0) + _check_input_types(y0, t) error_dtype = abs(y0).dtype rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') @@ -530,24 +543,74 @@ def odeint(func, class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): """Base class for fixed-grid ODE integrators.""" - def integrate(self, evol_func, y0, time_grid): - time_delta_grid = time_grid[1:] - time_grid[:-1] - - scan_func = self._make_scan_func(evol_func) + def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals): + """Returns integrated values of differential equation on the `time grid`. + + Numerically integrates differential equation defined via time derivative + evaluator `evol_func` using fixed time steps specified in dt_grid. + + Args: + evol_func: Callable, evaluates time derivative of y at a given time. + y0: N-D Tensor holds initial values of the solution. + time_grid: 1-D Tensor holding the time points at which the solution + will be recorded, must have a floating dtype. + dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid + intervals. Must be a floating dtype and have one less element than that + of the time_grid. + steps_on_intervals: 1-D Tensor of integer dtype, must have the same size + as dt_grid. Specifies number of steps needed for every interval. Assumes + steps_on_intervals * dt_grid == time intervals. + + Returns: + (N+1)-D tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `t`, with the initial value `y0` being the first element along the first + dimension. + """ - y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), - y0) - return array_ops.concat([[y0], y_grid], axis=0) + iteration_func = self._make_iteration_func(evol_func, dt_grid) + integrate_interval = self._make_interval_integrator(iteration_func, + steps_on_intervals) - def _make_scan_func(self, evol_func): + num_times = array_ops.size(time_grid) + current_time = time_grid[0] + solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times) + solution_array = solution_array.write(0, y0) - def scan_func(y, t_and_dt): - t, dt = t_and_dt + solution_array, _, _, _ = control_flow_ops.while_loop( + lambda _, __, ___, i: i < num_times, + integrate_interval, + (solution_array, y0, current_time, 1) + ) + solution_array = solution_array.stack() + solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape())) + return solution_array + + def _make_iteration_func(self, evol_func, dt_grid): + """Returns a function that builds operations of a single time step.""" + + def iteration_func(y, t, dt_step, interval_step): + """Performs a single time step advance.""" + dt = dt_grid[interval_step - 1] dy = self._step_func(evol_func, t, dt, y) dy = math_ops.cast(dy, dtype=y.dtype) - return y + dy + return y + dy, t + dt, dt_step + 1, interval_step + + return iteration_func + + def _make_interval_integrator(self, iteration_func, interval_sizes): + """Returns a function that builds operations for interval integration.""" - return scan_func + def integrate_interval(solution_array, y, t, interval_num): + """Integrates y with fixed time step on interval `interval_num`.""" + y, t, _, _ = control_flow_ops.while_loop( + lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1], + iteration_func, + (y, t, 0, interval_num) + ) + return solution_array.write(interval_num, y), y, t, interval_num + 1 + + return integrate_interval @abc.abstractmethod def _step_func(self, evol_func, t, dt, y): @@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): class _MidpointFixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing midpoint scheme.""" def _step_func(self, evol_func, t, dt, y): dt_cast = math_ops.cast(dt, y.dtype) @@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator): class _RK4FixedGridIntegrator(_FixedGridIntegrator): + """Fixed grid integrator implementing RK4 scheme.""" def _step_func(self, evol_func, t, dt, y): k1 = evol_func(y, t) @@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator): return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) -def odeint_fixed(func, y0, t, method='rk4', name=None): +def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None): """ODE integration on a fixed grid (with no step size control). Useful in certain scenarios to avoid the overhead of adaptive step size @@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): `y`. The initial time point should be the first element of this sequence, and each time must be larger than the previous time. May have any floating point dtype. + dt: 0-D or 1-D Tensor providing time step suggestion to be used on time + integration intervals in `t`. 1-D Tensor should provide values + for all intervals, must have 1 less element than that of `t`. + If given a 0-D Tensor, the value is interpreted as time step suggestion + same for all intervals. If passed None, then time step is set to be the + t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by + insuring an integer number of steps per interval, potentially reducing the + time step. method: One of 'midpoint' or 'rk4'. name: Optional name for the resulting operation. @@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None): Raises: ValueError: Upon caller errors. """ - with ops.name_scope(name, 'odeint_fixed', [y0, t]): + with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]): t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') y0 = ops.convert_to_tensor(y0, name='y0') - _check_input_types(t, y0) + + intervals = t[1:] - t[:-1] + if dt is None: + dt = intervals + dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt') + + steps_on_intervals = math_ops.ceil(intervals / dt) + dt = intervals / steps_on_intervals + steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32) + + _check_input_types(y0, t, dt) + _check_input_sizes(t, dt) with _assert_increasing(t): with ops.name_scope(method): if method == 'midpoint': - return _MidpointFixedGridIntegrator().integrate(func, y0, t) + return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) elif method == 'rk4': - return _RK4FixedGridIntegrator().integrate(func, y0, t) + return _RK4FixedGridIntegrator().integrate(func, y0, t, dt, + steps_on_intervals) else: raise ValueError('method not supported: {!s}'.format(method)) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index 3ec01212d25ca8dc6e13f340177a5e85138868d5..c7b4e2faa84e1a87cb1904b22eb0008ab1ee4be6 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase): class OdeIntFixedTest(test.TestCase): - def _test_integrate_sine(self, method): + def _test_integrate_sine(self, method, t, dt=None): def evol_func(y, t): del t return array_ops.stack([y[1], -y[0]]) y0 = [0., 1.] - time_grid = np.linspace(0., 10., 200) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2) - def _test_integrate_gaussian(self, method): + def _test_integrate_gaussian(self, method, t, dt=None): def evol_func(y, t): return -math_ops.cast(t, dtype=y.dtype) * y[0] y0 = [1.] - time_grid = np.linspace(0., 2., 100) - y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method) with self.test_session() as sess: y_grid_array = sess.run(y_grid) np.testing.assert_allclose( - y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) + y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2) + + def _test_integrate_sine_all(self, method): + uniform_time_grid = np.linspace(0., 10., 200) + non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0]) + uniform_dt = 0.02 + non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03]) + self._test_integrate_sine(method, uniform_time_grid) + self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt) + + def _test_integrate_gaussian_all(self, method): + uniform_time_grid = np.linspace(0., 2., 100) + non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0]) + uniform_dt = 0.01 + non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03]) + self._test_integrate_gaussian(method, uniform_time_grid) + self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt) + self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt) def _test_everything(self, method): - self._test_integrate_sine(method) - self._test_integrate_gaussian(method) + self._test_integrate_sine_all(method) + self._test_integrate_gaussian_all(method) def test_midpoint(self): self._test_everything('midpoint') @@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase): def test_rk4(self): self._test_everything('rk4') + def test_dt_size_exceptions(self): + times = np.linspace(0., 2., 100) + dt = np.ones(99) * 0.01 + dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03]) + dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0) + times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0) + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_length) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times, dt_wrong_dim) + + with self.assertRaises(ValueError): + self._test_integrate_gaussian('midpoint', times_wrong_dim, dt) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc index a4cd4a2cc4b99b5906185bd2b942ed15c1ddf5e4..2638b25ec424b5b4ef556ff769e94e64da32fec2 100644 --- a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -64,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel { eof_(eof), timeout_(timeout) {} - std::unique_ptr MakeIterator( + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr( new Iterator({this, strings::StrCat(prefix, "::Kafka")})); @@ -81,7 +81,7 @@ class KafkaDatasetOp : public DatasetOpKernel { return *shapes; } - string DebugString() override { return "KafkaDatasetOp::Dataset"; } + string DebugString() const override { return "KafkaDatasetOp::Dataset"; } protected: Status AsGraphDefInternal(DatasetGraphDefBuilder* b, diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md index 762a2f0b57e95e2fef3dd177070701afb410e93a..102626925db560e47cdc73eb1e25e08836cb4fba 100644 --- a/tensorflow/contrib/kfac/README.md +++ b/tensorflow/contrib/kfac/README.md @@ -1,5 +1,10 @@ # K-FAC: Kronecker-Factored Approximate Curvature +# WARNING: +# ==third_party/tensorflow/contrib/kfac is deprecated. This will be== +# ==removed on 15-07-2018. Please import third_party/tensorflow_kfac.== +# ==== + **K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an approximate second-order optimization method, in TensorFlow. When applied to feedforward and convolutional neural networks, K-FAC can converge `>3.5x` diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index b7f63d8d94a7a427eb57afefeda3939f0c530f8e..03b9da793307b966632789fd11162306e6cd19f9 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import warnings + # pylint disable=long-line from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp from tensorflow.contrib.kfac.python.ops import estimator as est @@ -107,6 +109,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): ValueError: If momentum is non-zero and momentum_type is not 'regular' or 'adam'. """ + warnings.warn( + "third_party.tensorflow.contrib.kfac is deprecated." + "This will be removed on 15-07-2018. Check README for further details.", + DeprecationWarning) # Parameters to be passed to the Fisher estimator: self._variables = var_list or tf_variables.trainable_variables self._cov_ema_decay = cov_ema_decay diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 00f03a111ae8be7f49761ef5fb5a82810bcca182..bc3359693562deb1229a78a2db5c256c76f7fd8d 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -19,6 +19,8 @@ See the @{$python/contrib.layers} guide. @@avg_pool2d @@avg_pool3d @@batch_norm +@@convolution +@@convolution1d @@convolution2d @@convolution3d @@conv2d_in_plane diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 2f3e57653c5d6d949c4dcc91635690322b7f90c4..b6d63c9640611abdda65f1205f544ee505dae1f0 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -57,10 +57,10 @@ from tensorflow.python.training import moving_averages __all__ = [ 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', - 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', - 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', - 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', - 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'convolution1d', 'convolution2d', 'convolution2d_in_plane', + 'convolution2d_transpose', 'convolution3d', 'convolution3d_transpose', + 'dense_to_sparse', 'dropout', 'elu', 'flatten', 'fully_connected', 'GDN', + 'gdn', 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', @@ -2022,6 +2022,7 @@ class GDN(base.Layer): def beta_initializer(shape, dtype=None, partition_info=None): del partition_info # unused + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(array_ops.ones(shape, dtype=dtype) + pedestal) def gamma_initializer(shape, dtype=None, partition_info=None): @@ -2029,6 +2030,7 @@ class GDN(base.Layer): assert len(shape) == 2 assert shape[0] == shape[1] eye = linalg_ops.eye(shape[0], dtype=dtype) + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) return math_ops.sqrt(self._gamma_init * eye + pedestal) beta = self.add_variable( diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 56e9194cebbe46907707f7ac0996f9a56fb53c0f..c5c7269b1f15849956e90654e3bcf8ab0eebc393 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1312,6 +1312,29 @@ class ConvolutionInPlaneTest(test.TestCase): self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) + def testConv1dShape(self): + width = 7 + with self.test_session(): + images = random_ops.random_uniform((5, width, 3), seed=1) + output = layers_lib.convolution1d(images, 32, 3) + self.assertEqual(output.op.name, 'Conv/Relu') + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + + def testConvInferSpatialDims(self): + depth, height, width = 7, 9, 11 + with self.test_session(): + images = np.random.uniform(size=(5, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3]) + self.assertListEqual(output.get_shape().as_list(), [5, width, 32]) + images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3]) + self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32]) + images = np.random.uniform(size=(5, depth, height, width, + 4)).astype(np.float32) + output = layers_lib.convolution(images, 32, [3, 3, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 32]) + class DenseToSparseTest(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 8ed9f446bcd5f222f486e43125dafc595852e5ce..0e35b1aa8bf682c1b4f7e8d974d3e8fad69e33cb 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -46,6 +46,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect __all__ = ["rev_block", "RevBlock", "recompute_grad"] @@ -449,6 +450,15 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False): `variable_scope(name, use_resource=True), which are the default in Eager mode and when running on TPU. + Warning: Because the function will be called again on the backwards pass, the + user should be careful to not use ops in their function that mutate state or + have randomness (for example, batch normalization or dropout). If the function + does have such operations, it is recommended that the function take the + `is_recomputing` keyword argument which will be `False` on the forward pass + and `True` on the backwards pass so that it can disable state changes when + `is_recomputing=True` (for example, not updating the moving averages in batch + normalization). + Args: fn: a function that takes Tensors (all as positional arguments) and returns a tuple of Tensors. @@ -482,6 +492,7 @@ def _is_on_tpu(): def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): """See recompute_grad.""" + has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args for arg in args: if not isinstance(arg, framework_ops.Tensor): raise ValueError("All inputs to function must be Tensors") @@ -496,7 +507,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): vs = variable_scope.get_variable_scope() arg_scope = contrib_framework_ops.current_arg_scope() with backprop.GradientTape() as tape: - outputs = fn(*args) + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = False + outputs = fn(*args, **fn_kwargs) original_vars = set(tape.watched_variables()) # Backward pass @@ -516,7 +530,10 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False): with contrib_framework_ops.arg_scope(arg_scope): with variable_scope.variable_scope(vs, reuse=True): with backprop.GradientTape() as tape: - outputs = fn(*inputs) + fn_kwargs = {} + if has_is_recompute_kwarg: + fn_kwargs["is_recomputing"] = True + outputs = fn(*inputs, **fn_kwargs) recompute_vars = set(tape.watched_variables()) if original_vars != recompute_vars: raise ValueError(_WRONG_VARS_ERR) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index 997f53b9e1bbf9ac151cadd4a9f8e79c2e0ebca2..bc09ba8d439808c1582f207a99504012afcf33a6 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -21,9 +21,11 @@ from __future__ import print_function from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.layers import rev_block_lib from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.layers import convolutional from tensorflow.python.layers import core as core_layers +from tensorflow.python.layers import normalization as normalization_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops @@ -342,6 +344,34 @@ class RecomputeTest(test.TestCase): for grad in grads: self.assertTrue(grad is not None) + def testWithIsRecomputeKwarg(self): + + kwarg_values = [] + + @rev_block_lib.recompute_grad + def layer_with_recompute(inputs, is_recomputing=False): + kwarg_values.append(is_recomputing) + out = core_layers.dense(inputs, 2) + out = normalization_layers.batch_normalization(out, training=True) + if is_recomputing: + # Ensure that the updates are not duplicated by popping off the latest + # 2 additions. + update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS) + update_ops.pop() + update_ops.pop() + return out + + x = array_ops.ones((2, 4), dtypes.float32) + with variable_scope.variable_scope("layer1", use_resource=True): + y = layer_with_recompute(x) + loss = math_ops.reduce_sum(y) + tvars = variables.trainable_variables() + gradients_impl.gradients(loss, [x] + tvars) + + update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS) + self.assertEqual(2, len(update_ops)) + self.assertEqual([False, True], kwarg_values) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 0fdbe8f6308e30db2043c400f37d7dcb6058d1f2..b56a88659bbd4467600788fc8e3e9dbf38ce8244 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -284,6 +284,7 @@ py_test( tags = [ "manual", "noasan", # times out + "optonly", # test is flaky without optimization. ], deps = [ ":learn", diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 70b70af98c51dcb991c19152607272673953ee2a..e100bc7a1e7be4896e9ab1c965775b5185b38897 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -31,7 +31,6 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -51,6 +50,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import session_run_hook from tensorflow.python.training import training as train +from tensorflow.python.training import training_util # The default learning rate of 0.2 is a historical artifact of the initial @@ -244,7 +244,9 @@ def sdca_model_fn(features, labels, mode, params): parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), + name_or_scope=parent_scope, + partitioner=optimizer.partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index 0a863f0e20c05d3372ffd8f7677cd518390ecc9d..597ca4e86dbf66c86182f14a2a364b662d52fb0a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test from tensorflow.python.training import ftrl from tensorflow.python.training import input as input_lib @@ -966,6 +967,63 @@ class LinearClassifierTest(test.TestCase): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearClassifier with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + classifier = linear.LinearClassifier( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + print('all scores = {}'.format(scores)) + self.assertGreater(scores['accuracy'], 0.9) + def testEval(self): """Tests that eval produces correct metrics. """ @@ -1540,6 +1598,60 @@ class LinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testSdcaOptimizerPartitionedVariables(self): + """Tests LinearRegressor with SDCAOptimizer with partitioned variables.""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([0.6, 0.8, 0.3]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id', symmetric_l2_regularization=1.0, + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + tf_config = { + 'cluster': { + run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1'] + } + } + with test.mock.patch.dict('os.environ', + {'TF_CONFIG': json.dumps(tf_config)}): + config = run_config.RunConfig() + # Because we did not start a distributed cluster, we need to pass an + # empty ClusterSpec, otherwise the device_setter will look for + # distributed jobs, such as "/job:ps" which are not present. + config._cluster_spec = server_lib.ClusterSpec({}) + + regressor = linear.LinearRegressor( + feature_columns=[price, sq_footage_bucket, country, sq_footage_country], + weight_column_name='weights', + optimizer=sdca_optimizer, + config=config) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """Tests LinearClassifier with SDCAOptimizer and sparse features.""" diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 541da9061732ad271f6d5456446a9c30b81e58dd..f8a3709ee57a32734afa7ac8133271c75d152b2c 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -505,7 +505,7 @@ class Experiment(object): eval_result = None last_warning_time = 0 while (not predicate_fn or predicate_fn( - eval_result, checkpoint_path=previous_path if eval_result else None)): + eval_result, checkpoint_path=previous_path)): # Exit if we have already reached number of steps to train. if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index d10927a0cdd5c67c8d2a8e569153235ee175ec4d..fb16c94c29660e2777942ea9cf30da51dbf90571 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -500,7 +500,7 @@ class ExperimentTest(test.TestCase): noop_hook = _NoopHook() def _predicate_fn(eval_result, checkpoint_path): - self.assertEqual(not eval_result, + self.assertEqual(eval_result is None, checkpoint_path is None) return est.eval_count < 3 # pylint: disable=cell-var-from-loop diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index b5741967ab52568725d7c9f03a0cc0b0f63f7459..ef0e08a777779e04f70d11fe83280ccaf1c178fd 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -35,6 +35,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import googletest @@ -132,15 +134,22 @@ def make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero): return examples_dict, variables_dict -def make_variable_dict(max_age, max_gender): +def make_variable_dict(max_age, max_gender, partitioned=False): # TODO(sibyl-toe9oF2e): Figure out how to derive max_age & max_gender from # examples_dict. - age_weights = variables_lib.Variable( - array_ops.zeros( - [max_age + 1], dtype=dtypes.float32)) - gender_weights = variables_lib.Variable( - array_ops.zeros( - [max_gender + 1], dtype=dtypes.float32)) + partitioner = None + if partitioned: + partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2, + axis=0) + with variable_scope.variable_scope( + name_or_scope='variables', + partitioner=partitioner): + age_weights = variables_lib.Variable( + array_ops.zeros( + [max_age + 1], dtype=dtypes.float32)) + gender_weights = variables_lib.Variable( + array_ops.zeros( + [max_gender + 1], dtype=dtypes.float32)) return dict( sparse_features_weights=[age_weights, gender_weights], dense_features_weights=[]) @@ -265,6 +274,54 @@ class SdcaWithLogisticLossTest(SdcaModelTest): self.assertAllClose( 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testPartitionedPrimals(self): + # Setup test data + example_protos = [ + make_example_proto({ + 'age': [0], + 'gender': [0] + }, 0), + make_example_proto({ + 'age': [1], + 'gender': [1] + }, 1), + ] + example_weights = [1.0, 1.0] + for num_shards in _SHARD_NUMBERS: + with self._single_threaded_test_session(): + examples = make_example_dict(example_protos, example_weights) + variables = make_variable_dict(1, 1, partitioned=True) + options = dict( + symmetric_l2_regularization=1, + symmetric_l1_regularization=0, + num_table_shards=num_shards, + loss_type='logistic_loss') + + lr = SdcaModel(examples, variables, options) + variables_lib.global_variables_initializer().run() + unregularized_loss = lr.unregularized_loss(examples) + loss = lr.regularized_loss(examples) + predictions = lr.predictions(examples) + self.assertAllClose(0.693147, unregularized_loss.eval()) + self.assertAllClose(0.693147, loss.eval()) + train_op = lr.minimize() + for _ in range(_MAX_ITERATIONS): + train_op.run() + lr.update_weights(train_op).run() + # The high tolerance in unregularized_loss comparisons is due to the + # fact that it's possible to trade off unregularized_loss vs. + # regularization and still have a sum that is quite close to the + # optimal regularized_loss value. SDCA's duality gap only ensures that + # the regularized_loss is within 0.01 of optimal. + # 0.525457 is the optimal regularized_loss. + # 0.411608 is the unregularized_loss at that optimum. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) + self.assertAllClose(0.525457, loss.eval(), atol=0.01) + predicted_labels = get_binary_predictions_for_logistic(predictions) + self.assertAllEqual([0, 1], predicted_labels.eval()) + self.assertAllClose( + 0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2) + def testSparseRandom(self): dim = 20 num_examples = 1000 @@ -320,7 +377,10 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op.run() def testDistributedSimple(self): - # Setup test data + # Distributed SDCA may not converge if the workers update concurrently the + # same example. In this test the examples are partitioned across workers. + # The examples are the same for all workers, just the example_ids are + # different. example_protos = [ make_example_proto({ 'age': [0], @@ -332,13 +392,19 @@ class SdcaWithLogisticLossTest(SdcaModelTest): }, 1), ] example_weights = [1.0, 1.0] + examples = make_example_dict(example_protos, example_weights) + example_ids = array_ops.placeholder( + dtypes.string, shape=(len(example_weights),)) + examples['example_ids'] = example_ids + variables = make_variable_dict(1, 1) for num_shards in _SHARD_NUMBERS: for num_loss_partitions in _NUM_LOSS_PARTITIONS: with self._single_threaded_test_session(): - examples = make_example_dict(example_protos, example_weights) - variables = make_variable_dict(1, 1) options = dict( - symmetric_l2_regularization=1, + # Keep the same solution as for TestSimple: since the number of + # examples is multplied by num_loss_partitions, multiply also + # L2 by the same value. + symmetric_l2_regularization=num_loss_partitions, symmetric_l1_regularization=0, loss_type='logistic_loss', num_table_shards=num_shards, @@ -354,32 +420,30 @@ class SdcaWithLogisticLossTest(SdcaModelTest): train_op = lr.minimize() - def minimize(): + def minimize(worker_id): with self._single_threaded_test_session(): + feed_dict = {example_ids: [ + str(i + worker_id*len(example_weights)) for i in range( + len(example_weights))]} for _ in range(_MAX_ITERATIONS): - train_op.run() # pylint: disable=cell-var-from-loop + train_op.run(feed_dict=feed_dict) # pylint: disable=cell-var-from-loop threads = [] - for _ in range(num_loss_partitions): - threads.append(threading.Thread(target=minimize)) + for worker_id in range(num_loss_partitions): + threads.append(threading.Thread(target=minimize, args=(worker_id,))) threads[-1].start() for t in threads: t.join() - lr.update_weights(train_op).run() - - # The high tolerance in unregularized_loss comparisons is due to the - # fact that it's possible to trade off unregularized_loss vs. - # regularization and still have a sum that is quite close to the - # optimal regularized_loss value. SDCA's duality gap only ensures - # that the regularized_loss is within 0.01 of optimal. - # 0.525457 is the optimal regularized_loss. - # 0.411608 is the unregularized_loss at that optimum. - self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05) - self.assertAllClose(0.525457, loss.eval(), atol=0.01) + lr.update_weights(train_op).run(feed_dict={ + example_ids: [str(i) for i in range(len(example_weights))]}) + + # Test only the unregularized loss because the optimal value of the + # regularized loss depends on num_loss_partitions. + self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02) predicted_labels = get_binary_predictions_for_logistic(predictions) self.assertAllEqual([0, 1], predicted_labels.eval()) - self.assertTrue(lr.approximate_duality_gap().eval() < 0.02) + self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02) def testSimpleNoL2(self): # Same as test above (so comments from above apply) but without an L2. diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index f980746a19fb8e0a02b9d023c127da7ab33e457f..0047d5753a773ce814d685f89da9ae6b04d21cb6 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -22,12 +22,14 @@ import collections from six.moves import range from tensorflow.contrib.linear_optimizer.python.ops.sharded_mutable_dense_hashtable import ShardedMutableDenseHashTable +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework.ops import internal_convert_to_tensor from tensorflow.python.framework.ops import name_scope from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_sdca_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -43,9 +45,6 @@ __all__ = ['SdcaModel'] class SdcaModel(object): """Stochastic dual coordinate ascent solver for linear models. - This class currently only supports a single machine (multi-threaded) - implementation. We expect the weights and duals to fit in a single machine. - Loss functions supported: * Binary logistic loss @@ -182,18 +181,41 @@ class SdcaModel(object): # TODO(sibyl-Aix6ihai): Use optimizer interface to make use of slot creation logic. def _create_slots(self): - # Make internal variables which have the updates before applying L1 - # regularization. + """Make unshrinked internal variables (slots).""" + # Unshrinked variables have the updates before applying L1 regularization. + # Each unshrinked slot variable is either a `Variable` or list of + # `Variable`, depending on the value of its corresponding primary variable. + # We avoid using `PartitionedVariable` for the unshrinked slots since we do + # not need any of the extra information. self._slots = collections.defaultdict(list) for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is - # fixed - self._slots['unshrinked_' + name].append( - var_ops.Variable( - array_ops.zeros_like(var.initialized_value(), dtypes.float32), - name=var.op.name + '_unshrinked/SDCAOptimizer')) + # Our primary variable may be either a PartitionedVariable, or a list + # of Variables (each representing a partition). + if (isinstance(var, var_ops.PartitionedVariable) or + isinstance(var, list)): + var_list = [] + # pylint: disable=protected-access + for v in var: + with ops.colocate_with(v): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 + # is fixed. + slot_var = var_ops.Variable( + initial_value=array_ops.zeros_like(v.initialized_value(), + dtypes.float32), + name=v.op.name + '_unshrinked/SDCAOptimizer') + var_list.append(slot_var) + self._slots['unshrinked_' + name].append(var_list) + # pylint: enable=protected-access + else: + with ops.device(var.device): + # TODO(andreasst): remove SDCAOptimizer suffix once bug 30843109 is + # fixed. + self._slots['unshrinked_' + name].append( + var_ops.Variable( + array_ops.zeros_like(var.initialized_value(), + dtypes.float32), + name=var.op.name + '_unshrinked/SDCAOptimizer')) def _assertSpecified(self, items, check_in): for x in items: @@ -205,16 +227,25 @@ class SdcaModel(object): if not isinstance(check_in[x], list): raise ValueError(x + ' must be a list.') + def _var_to_list(self, var): + """Wraps var in a list if it is not a list or PartitionedVariable.""" + if not (isinstance(var, list) or + isinstance(var, var_ops.PartitionedVariable)): + var = [var] + return var + def _l1_loss(self): """Computes the (un-normalized) l1 loss of the model.""" with name_scope('sdca/l1_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.abs(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append( + math_ops.reduce_sum( + math_ops.abs(math_ops.cast(weights, dtypes.float64)))) # SDCA L1 regularization cost is: l1 * sum(|weights|) return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) @@ -223,17 +254,37 @@ class SdcaModel(object): with name_scope('sdca/l2_loss'): sums = [] for name in ['sparse_features_weights', 'dense_features_weights']: - for weights in self._convert_n_to_tensor(self._variables[name]): - with ops.device(weights.device): - sums.append( - math_ops.reduce_sum( - math_ops.square(math_ops.cast(weights, dtypes.float64)))) + for var in self._variables[name]: + for v in self._var_to_list(var): + weights = internal_convert_to_tensor(v) + with ops.device(weights.device): + sums.append(math_ops.reduce_sum(math_ops.square(math_ops.cast( + weights, dtypes.float64)))) # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 return l2 * math_ops.add_n(sums) / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" - return [internal_convert_to_tensor(x, as_ref=as_ref) for x in input_list] + # input_list can be a list of Variables (that are implicitly partitioned), + # in which case the underlying logic in internal_convert_to_tensor will not + # concatenate the partitions together. This method takes care of the + # concatenating (we only allow partitioning on the first axis). + output_list = [] + for x in input_list: + tensor_to_convert = x + if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable): + # We only allow for partitioning on the first axis. + tensor_to_convert = array_ops.concat(x, axis=0) + output_list.append(internal_convert_to_tensor( + tensor_to_convert, as_ref=as_ref)) + return output_list + + def _get_first_dimension_size_statically(self, w, num_partitions): + """Compute the static size of the first dimension for a sharded variable.""" + dim_0_size = w[0].get_shape()[0] + for p in range(1, num_partitions): + dim_0_size += w[p].get_shape()[0] + return dim_0_size def _linear_predictions(self, examples): """Returns predictions of the form w*x.""" @@ -286,6 +337,28 @@ class SdcaModel(object): result = math_ops.sigmoid(result) return result + def _get_partitioned_update_ops(self, + v_num, + num_partitions_by_var, + p_assignments_by_var, + gather_ids_by_var, + weights, + full_update, + p_assignments, + num_partitions): + """Get updates for partitioned variables.""" + num_partitions = num_partitions_by_var[v_num] + p_assignments = p_assignments_by_var[v_num] + gather_ids = gather_ids_by_var[v_num] + updates = data_flow_ops.dynamic_partition( + full_update, p_assignments, num_partitions) + update_ops = [] + for p in range(num_partitions): + with ops.colocate_with(weights[p]): + result = state_ops.scatter_add(weights[p], gather_ids[p], updates[p]) + update_ops.append(result) + return update_ops + def minimize(self, global_step=None, name=None): """Add operations to train a linear model by minimizing the loss function. @@ -318,18 +391,89 @@ class SdcaModel(object): # Solver returns example_state_update, new delta sparse_feature_weights # and delta dense_feature_weights. - weights_tensor = self._convert_n_to_tensor(self._slots[ - 'unshrinked_sparse_features_weights']) sparse_weights = [] sparse_indices = [] - for w, i in zip(weights_tensor, sparse_feature_indices): - # Find the feature ids to lookup in the variables. - with ops.device(w.device): - sparse_indices.append( - math_ops.cast( - array_ops.unique(math_ops.cast(i, dtypes.int32))[0], - dtypes.int64)) - sparse_weights.append(array_ops.gather(w, sparse_indices[-1])) + # If we have partitioned variables, keep a few lists of Tensors around + # that we need for the assign_add after the op call to + # gen_sdca_ops.sdca_optimizer(). + num_partitions_by_var = [] + p_assignments_by_var = [] + gather_ids_by_var = [] + for w, i in zip(self._slots['unshrinked_sparse_features_weights'], + sparse_feature_indices): + # Append the sparse_indices (in full-variable space). + sparse_idx = math_ops.cast( + array_ops.unique(math_ops.cast(i, dtypes.int32))[0], + dtypes.int64) + sparse_indices.append(sparse_idx) + if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable): + num_partitions = len(w) + flat_ids = array_ops.reshape(sparse_idx, [-1]) + # We use div partitioning, which is easiest to support downstream. + # Compute num_total_ids as the sum of dim-0 of w, then assign + # to partitions based on a constant number of ids per partition. + # Optimize if we already know the full shape statically. + dim_0_size = self._get_first_dimension_size_statically( + w, num_partitions) + + if dim_0_size.value: + num_total_ids = constant_op.constant(dim_0_size.value, + flat_ids.dtype) + else: + dim_0_sizes = [] + for p in range(num_partitions): + if w[p].get_shape()[0].value is not None: + dim_0_sizes.append(w[p].get_shape()[0].value) + else: + with ops.colocate_with(w[p]): + dim_0_sizes.append(array_ops.shape(w[p])[0]) + num_total_ids = math_ops.reduce_sum( + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) + ids_per_partition = num_total_ids // num_partitions + extras = num_total_ids % num_partitions + + p_assignments = math_ops.maximum( + flat_ids // (ids_per_partition + 1), + (flat_ids - extras) // ids_per_partition) + + # Emulate a conditional using a boolean indicator tensor + new_ids = array_ops.where(p_assignments < extras, + flat_ids % (ids_per_partition + 1), + (flat_ids - extras) % ids_per_partition) + + # Cast partition assignments to int32 for use in dynamic_partition. + # There really should not be more than 2^32 partitions. + p_assignments = math_ops.cast(p_assignments, dtypes.int32) + # Partition list of ids based on assignments into num_partitions + # separate lists. + gather_ids = data_flow_ops.dynamic_partition(new_ids, + p_assignments, + num_partitions) + # Append these to the lists for use in the later update. + num_partitions_by_var.append(num_partitions) + p_assignments_by_var.append(p_assignments) + gather_ids_by_var.append(gather_ids) + + # Gather the weights from each partition. + partition_gathered_weights = [] + for p in range(num_partitions): + with ops.colocate_with(w[p]): + partition_gathered_weights.append( + array_ops.gather(w[p], gather_ids[p])) + + # Stitch the weights back together in the same order they were before + # we dynamic_partitioned them. + condition_indices = data_flow_ops.dynamic_partition( + math_ops.range(array_ops.shape(new_ids)[0]), + p_assignments, num_partitions) + batch_gathered_weights = data_flow_ops.dynamic_stitch( + condition_indices, partition_gathered_weights) + else: + w_as_tensor = internal_convert_to_tensor(w) + with ops.device(w_as_tensor.device): + batch_gathered_weights = array_ops.gather( + w_as_tensor, sparse_idx) + sparse_weights.append(batch_gathered_weights) # pylint: disable=protected-access esu, sfw, dfw = gen_sdca_ops.sdca_optimizer( @@ -355,12 +499,25 @@ class SdcaModel(object): with ops.control_dependencies([esu]): update_ops = [self._hashtable.insert(example_ids_hashed, esu)] # Update the weights before the proximal step. - for w, i, u in zip(self._slots['unshrinked_sparse_features_weights'], - sparse_indices, sfw): - update_ops.append(state_ops.scatter_add(w, i, u)) + for v_num, (w, i, u) in enumerate( + zip(self._slots['unshrinked_sparse_features_weights'], + sparse_indices, sfw)): + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + update_ops += self._get_partitioned_update_ops( + v_num, num_partitions_by_var, p_assignments_by_var, + gather_ids_by_var, w, u, p_assignments, num_partitions) + else: + update_ops.append(state_ops.scatter_add(w, i, u)) for w, u in zip(self._slots['unshrinked_dense_features_weights'], dfw): - update_ops.append(w.assign_add(u)) - + if (isinstance(w, var_ops.PartitionedVariable) or + isinstance(w, list)): + split_updates = array_ops.split( + u, num_or_size_splits=[v.shape.as_list()[0] for v in w]) + for v, split_update in zip(w, split_updates): + update_ops.append(state_ops.assign_add(v, split_update)) + else: + update_ops.append(state_ops.assign_add(w, u)) if not global_step: return control_flow_ops.group(*update_ops) with ops.control_dependencies(update_ops): @@ -385,21 +542,22 @@ class SdcaModel(object): for name in ['sparse_features_weights', 'dense_features_weights']: for var, slot_var in zip(self._variables[name], self._slots['unshrinked_' + name]): - update_ops.append(var.assign(slot_var)) + for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)): + update_ops.append(v.assign(sv)) # Apply proximal step. with ops.control_dependencies(update_ops): update_ops = [] for name in ['sparse_features_weights', 'dense_features_weights']: for var in self._variables[name]: - with ops.device(var.device): - # pylint: disable=protected-access - update_ops.append( - gen_sdca_ops.sdca_shrink_l1( - self._convert_n_to_tensor( - [var], as_ref=True), - l1=self._symmetric_l1_regularization(), - l2=self._symmetric_l2_regularization())) + for v in self._var_to_list(var): + with ops.device(v.device): + # pylint: disable=protected-access + update_ops.append( + gen_sdca_ops.sdca_shrink_l1( + self._convert_n_to_tensor([v], as_ref=True), + l1=self._symmetric_l1_regularization(), + l2=self._symmetric_l2_regularization())) return control_flow_ops.group(*update_ops) def approximate_duality_gap(self): diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index d4e54c82f988e0adcd16aad29702ee9f8b16aea3..200e7de6b95f17672c6ef51f887b15f9d185f775 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -116,6 +116,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): num_loss_partitions = params["num_loss_partitions"] weight_column_name = params["weight_column_name"] update_weights_hook = params.get("update_weights_hook", None) + partitioner = params["partitioner"] loss_type = None if isinstance(head, head_lib._BinarySvmHead): # pylint: disable=protected-access @@ -136,12 +137,14 @@ def sdca_model_fn(features, labels, mode, params, config=None): example_id_column=example_id_column, num_loss_partitions=n_loss_partitions, symmetric_l1_regularization=l1_regularization, - symmetric_l2_regularization=l2_regularization) + symmetric_l2_regularization=l2_regularization, + partitioner=partitioner) parent_scope = "linear" with variable_scope.variable_scope( - values=features.values(), name_or_scope=parent_scope) as scope: + values=features.values(), name_or_scope=parent_scope, + partitioner=partitioner) as scope: features = features.copy() features.update(layers.transform_features(features, feature_columns)) logits, columns_to_variables, bias = ( @@ -213,7 +216,8 @@ class _SDCAEstimator(estimator.Estimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `_SDCAEstimator` estimator object. Args: @@ -241,6 +245,8 @@ class _SDCAEstimator(estimator.Estimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `_SDCAEstimator` estimator. @@ -267,6 +273,7 @@ class _SDCAEstimator(estimator.Estimator): "l2_regularization": l2_regularization, "weight_column_name": weight_column_name, "update_weights_hook": _SdcaUpdateWeightsHook(), + "partitioner": partitioner, } super(_SDCAEstimator, self).__init__( @@ -336,7 +343,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALogisticClassifier` object. Args: @@ -361,6 +369,8 @@ class SDCALogisticClassifier(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALogisiticClassifier` estimator. @@ -376,7 +386,8 @@ class SDCALogisticClassifier(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_classes(self, input_fn=None): """Runs inference to determine the predicted class. @@ -463,7 +474,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=1.0, num_loss_partitions=None, config=None, - feature_engineering_fn=None): + feature_engineering_fn=None, + partitioner=None): """Construct a `SDCALinearRegressor` estimator object. @@ -489,6 +501,8 @@ class SDCALinearRegressor(_SDCAEstimator): feature_engineering_fn: Feature engineering function. Takes features and labels which are the output of `input_fn` and returns features and labels which will be fed into the model. + partitioner: Variable partitioner for the primal weights (`div` + partitioning strategy will be used). Returns: A `SDCALinearRegressor` estimator. @@ -503,7 +517,8 @@ class SDCALinearRegressor(_SDCAEstimator): l2_regularization=l2_regularization, num_loss_partitions=num_loss_partitions, config=config, - feature_engineering_fn=None) + feature_engineering_fn=None, + partitioner=partitioner) def predict_scores(self, input_fn): """Returns predicted scores for given features. diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py index bed3d5139fcbf9d9e8b85605c752736f26af6793..647667188238dc18b137eaad98356a79b3a549b4 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator_test.py @@ -25,6 +25,7 @@ from tensorflow.contrib.linear_optimizer.python import sdca_estimator from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import test @@ -273,6 +274,47 @@ class SDCALogisticClassifierTest(test.TestCase): metrics = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(metrics['accuracy'], 0.9) + def testPartitionedMixedFeatures(self): + """Tests SDCALogisticClassifier with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([900.0, 700.0, 600.0]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [1.0], [1.0]]) + }, constant_op.constant([[1], [0], [1]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + classifier = sdca_estimator.SDCALogisticClassifier( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + classifier.fit(input_fn=input_fn, steps=50) + metrics = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(metrics['accuracy'], 0.9) + class SDCALinearRegressorTest(test.TestCase): @@ -350,6 +392,48 @@ class SDCALinearRegressorTest(test.TestCase): loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] self.assertLess(loss, 0.05) + def testMixedFeaturesArbitraryWeightsPartitioned(self): + """Tests SDCALinearRegressor works with a mix of features (partitioned).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + constant_op.constant([[0.6], [0.8], [0.3]]), + 'sq_footage': + constant_op.constant([[900.0], [700.0], [600.0]]), + 'country': + sparse_tensor.SparseTensor( + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 3], [2, 1]], + dense_shape=[3, 5]), + 'weights': + constant_op.constant([[3.0], [5.0], [7.0]]) + }, constant_op.constant([[1.55], [-1.25], [-3.0]]) + + with self._single_threaded_test_session(): + price = feature_column_lib.real_valued_column('price') + sq_footage_bucket = feature_column_lib.bucketized_column( + feature_column_lib.real_valued_column('sq_footage'), + boundaries=[650.0, 800.0]) + country = feature_column_lib.sparse_column_with_hash_bucket( + 'country', hash_bucket_size=5) + sq_footage_country = feature_column_lib.crossed_column( + [sq_footage_bucket, country], hash_bucket_size=10) + regressor = sdca_estimator.SDCALinearRegressor( + example_id_column='example_id', + feature_columns=[ + price, sq_footage_bucket, country, sq_footage_country + ], + l2_regularization=1.0, + weight_column_name='weights', + partitioner=partitioned_variables.fixed_size_partitioner( + num_shards=2, axis=0)) + regressor.fit(input_fn=input_fn, steps=20) + loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss'] + self.assertLess(loss, 0.05) + def testSdcaOptimizerSparseFeaturesWithL1Reg(self): """SDCALinearRegressor works with sparse features and L1 regularization.""" diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py index 12039ecc6f357af07e0c2a08e17d46396f3ad386..9872c6f97c879d8994b6c26e65df33e368a0603e 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py @@ -64,7 +64,8 @@ class SDCAOptimizer(object): of workers running the train steps. It defaults to 1 (single machine). `num_table_shards` defines the number of shards for the internal state table, typically set to match the number of parameter servers for large - data sets. + data sets. You can also specify a `partitioner` object to partition the primal + weights during training (`div` partitioning strategy will be used). """ def __init__(self, @@ -73,13 +74,15 @@ class SDCAOptimizer(object): num_table_shards=None, symmetric_l1_regularization=0.0, symmetric_l2_regularization=1.0, - adaptive=True): + adaptive=True, + partitioner=None): self._example_id_column = example_id_column self._num_loss_partitions = num_loss_partitions self._num_table_shards = num_table_shards self._symmetric_l1_regularization = symmetric_l1_regularization self._symmetric_l2_regularization = symmetric_l2_regularization self._adaptive = adaptive + self._partitioner = partitioner def get_name(self): return 'SDCAOptimizer' @@ -108,6 +111,10 @@ class SDCAOptimizer(object): def adaptive(self): return self._adaptive + @property + def partitioner(self): + return self._partitioner + def get_train_step(self, columns_to_variables, weight_column_name, loss_type, features, targets, global_step): """Returns the training operation of an SdcaModel optimizer.""" @@ -175,10 +182,12 @@ class SDCAOptimizer(object): sparse_feature_column = _dense_tensor_to_sparse_feature_column( dense_bucket_tensor) sparse_feature_with_values.append(sparse_feature_column) - # For bucketized columns, the variables list contains exactly one - # element. - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) elif isinstance( column, ( @@ -226,8 +235,12 @@ class SDCAOptimizer(object): array_ops.shape(ids)[0]), [-1]) sparse_feature_with_values.append( SparseFeatureColumn(example_ids_filtered, reproject_ids, weights)) - sparse_feature_with_values_weights.append( - columns_to_variables[column][0]) + # If a partitioner was used during variable creation, we will have a + # list of Variables here larger than 1. + vars_to_append = columns_to_variables[column][0] + if len(columns_to_variables[column]) > 1: + vars_to_append = columns_to_variables[column] + sparse_feature_with_values_weights.append(vars_to_append) else: raise ValueError('SDCAOptimizer does not support column type %s.' % type(column).__name__) diff --git a/tensorflow/contrib/lite/Android.bp b/tensorflow/contrib/lite/Android.bp index bd470696c5879822f9a75b77b3ec132500cd0d34..ff3d18b4b117863ab483b4ff9c5dac71fb5379ee 100644 --- a/tensorflow/contrib/lite/Android.bp +++ b/tensorflow/contrib/lite/Android.bp @@ -45,7 +45,7 @@ cc_library_static { "graph_info.cc", "interpreter.cc", "model.cc", - "op_resolver.cc", + "op_resolver.cc", "nnapi_delegate.cc", "optional_debug_tools.cc", "simple_memory_arena.cc", diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 55b984f260ec49ab9b52be6402885a46226cba70..9c804d27854b8004d34c65691b48ca2b0d3bbf7c 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -90,6 +90,16 @@ cc_library( deps = [":context"], ) +cc_library( + name = "kernel_api", + hdrs = [ + "builtin_op_data.h", + "builtin_ops.h", + "context.h", + "context_util.h", + ], +) + exports_files(["builtin_ops.h"]) cc_library( diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc index a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441..16171df10a7b18b22919c6e54fe3d1e8e0120f69 100644 --- a/tensorflow/contrib/lite/arena_planner_test.cc +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -209,11 +209,8 @@ TEST_F(ArenaPlannerTest, ZeroSizedTensors) { TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); (*graph.tensors())[1].bytes = 0; SetGraph(&graph); - // TODO(ahentz): this is currently broken because the arena finds two - // allocations with the same offset and returns an error. - ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk); - // EXPECT_EQ(GetOffset(1), 0); - // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + ASSERT_EQ(planner_->ExecuteAllocations(0, 10), kTfLiteOk); + EXPECT_EQ((*graph_->tensors())[1].data.raw, nullptr); } TEST_F(ArenaPlannerTest, SimpleGraph) { diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 9bfc0a0fbeff38fb77b6d67c1a2df37a6807528c..974e6c5d98e5691a3733495e915d919c4bf57d3a 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -204,7 +204,9 @@ def generated_test_models(): "conv", "depthwiseconv", "div", + "equal", "exp", + "expand_dims", "floor", "fully_connected", "fused_batch_norm", @@ -212,18 +214,22 @@ def generated_test_models(): "global_batch_norm", "greater", "greater_equal", - "l2_pool", "l2norm", + "l2_pool", "less", "less_equal", "local_response_norm", "log_softmax", + "log", + # TODO(b/110143200): Enable after resolving issues with LSTM conversion. + # "lstm", "max_pool", "maximum", "mean", "minimum", "mul", "neg", + "not_equal", "pad", "padv2", # "prelu", @@ -238,11 +244,13 @@ def generated_test_models(): "softmax", "space_to_batch_nd", "space_to_depth", + "sparse_to_dense", "split", "squeeze", "strided_slice", "strided_slice_1d_exhaustive", "sub", + "tile", "topk", "transpose", "transpose_conv", diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 8660c653ae4c0c69e4f5ad8fae739c8c1db7414c..c1cc4476fbd45fa6b3f5b3a1ed2cba39cc2ad54b 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -148,10 +148,20 @@ typedef struct { float beta; } TfLiteLocalResponseNormParams; +typedef enum { + kTfLiteLSTMFullKernel = 0, + kTfLiteLSTMBasicKernel +} TfLiteLSTMKernelType; + typedef struct { + // Parameters for LSTM version 1. TfLiteFusedActivation activation; float cell_clip; float proj_clip; + + // Parameters for LSTM version 2. + // kTfLiteLSTMBasicKernel is only supported in version 2 or above. + TfLiteLSTMKernelType kernel_type; } TfLiteLSTMParams; typedef struct { @@ -236,6 +246,10 @@ typedef struct { int stride_height; } TfLiteTransposeConvParams; +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 7e285186f45a61a451fd7328b061e16059049ea5..aef9a92883f18dabfc36058507d739856c3c2af7 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { @@ -93,10 +93,15 @@ typedef enum { kTfLiteBuiltinSlice = 65, kTfLiteBuiltinSin = 66, kTfLiteBuiltinTransposeConv = 67, + kTfLiteBuiltinSparseToDense = 68, + kTfLiteBuiltinTile = 69, + kTfLiteBuiltinExpandDims = 70, + kTfLiteBuiltinEqual = 71, + kTfLiteBuiltinNotEqual = 72, + kTfLiteBuiltinLog = 73, } TfLiteBuiltinOperator; #ifdef __cplusplus } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h new file mode 100644 index 0000000000000000000000000000000000000000..abe802e34214caf4d5063da827b3aca4a82aa56d --- /dev/null +++ b/tensorflow/contrib/lite/context_util.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This provides a few C++ helpers that are useful for manipulating C structures +// in C++. +#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. + explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + size_t size() const { return end() - begin(); } + + private: + const TfLiteIntArray* int_array_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..35a8f6ca4166e373ea1a0af5d4a013327b30d2b6 --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD @@ -0,0 +1,31 @@ +package(default_visibility = [ + "//visibility:public", +]) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "nnapi_delegate", + srcs = ["nnapi_delegate.cc"], + hdrs = ["nnapi_delegate.h"], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:kernel_api", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + ], +) + +tf_cc_test( + name = "nnapi_delegate_test", + size = "small", + srcs = ["nnapi_delegate_test.cc"], + deps = [ + ":nnapi_delegate", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc new file mode 100644 index 0000000000000000000000000000000000000000..0731d14419d2dec2ea5efa48ef5d4b7728af635f --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -0,0 +1,464 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/builtin_ops.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/context_util.h" +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" + +namespace tflite { +namespace { + +// TODO(b/80621585): Consider printing error string, but don't for now to +// minimize binary size. +#define CHECK_NN(context, code) \ + if (code != ANEURALNETWORKS_NO_ERROR) { \ + context->ReportError(context, "NN API returned error (%d).\n", code); \ + return kTfLiteError; \ + } + +// RAII NN API Model Destructor for use with std::unique_ptr +struct NNFreeModel { + void operator()(ANeuralNetworksModel* model) { + ANeuralNetworksModel_free(model); + } +}; +// RAII NN API Compilation Destructor for use with std::unique_ptr +struct NNFreeCompilation { + void operator()(ANeuralNetworksCompilation* model) { + ANeuralNetworksCompilation_free(model); + } +}; + +// Track tensor indices to NN API tensor indices mapping. +class OperandMapping { + public: + // Given a TFLite index return the ANN index. If it doesn't exist + // return -1. + int lite_index_to_ann(int index) const { + if (index < lite_tensor_to_ann_tensor_.size()) + return lite_tensor_to_ann_tensor_[index]; + else + return -1; + } + + // NN API uses non tensor operands instead of structs. This creates one + // and returns the index. It uses a std::vector and resizes it as needed + // keeping -1 to unmapped values. Intermediate tensors likely will not + // be mapped. + int add_new_non_tensor_operand() { return next_ann_tensor_index_++; } + + // Add a new mapping from `tflite_index` and return the NN API tensor index. + int add_new_ann_tensor_index(int tflite_index) { + if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { + lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + } + int new_tensor_index = next_ann_tensor_index_++; + lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; + return new_tensor_index; + } + + private: + // Next index of ann tensor + int next_ann_tensor_index_ = 0; + + // Mapping from lite index. Use a std::vector for speed and code size + // rather than a map. + std::vector lite_tensor_to_ann_tensor_; +}; + +// Abstract builder for building an op in the NN API graph. This handles +// the disparity between TFLite and NN API operand types. NN API has singular +// operands for both tensors and parameters, and TFLite separates the two. +class NNAPIOpBuilder { + public: + NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping, + ANeuralNetworksModel* nn_model) + : context_(context), + operand_mapping_(tensor_mapping), + nn_model_(nn_model) {} + + TfLiteStatus AddScalarInt32Operand(int value) { + ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(int32_t))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + + TfLiteStatus AddTensorInput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_inputs_.push_back(ann_index); + return kTfLiteOk; + } + + TfLiteStatus AddTensorOutput(int tensor_index) { + int ann_index; + TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index)); + augmented_outputs_.push_back(ann_index); + return kTfLiteOk; + } + + // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`. + // This returns the NN API tensor index corresponding to the created tensor. + // If another caller previously created a NN API tensor for `tensor_index` + // then the existing one is returned. + TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) { + int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index); + if (ann_tensor_index != -1) { + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + // Allocate a new tensor index + ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index); + + // Parameters needed for new type. + int32_t nn_type = 0; + float scale = 0.0f; + int32_t zeroPoint = 0; + TfLiteTensor* tensor = &context_->tensors[tensor_index]; + switch (tensor->type) { + case kTfLiteNoType: + // Tensors added during initialization of Ops don't have a type yet and + // should not be registered with the NNAPI. + *ann_tensor_index_out = -1; + return kTfLiteOk; + case kTfLiteFloat32: + nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; + scale = 0.f; + break; + case kTfLiteUInt8: + nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; + break; + case kTfLiteInt32: + nn_type = ANEURALNETWORKS_TENSOR_INT32; + scale = 0.f; + zeroPoint = 0; + break; + default: + context_->ReportError(context_, "Logic error in NN API Delegate.\n"); + return kTfLiteError; + } + + ANeuralNetworksOperandType operand_type{ + nn_type, static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), scale, zeroPoint}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + + if (tensor->allocation_type == kTfLiteMmapRo) { + // TODO(b/80630405): Use NNAPIAllocation. + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_tensor_index, tensor->data.raw, + tensor->bytes)); + } + + *ann_tensor_index_out = ann_tensor_index; + return kTfLiteOk; + } + + // Finish emitting the op (of type `type`) into the NN API. + TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) { + // Actually add a NN API operation + CHECK_NN(context_, ANeuralNetworksModel_addOperation( + nn_model_, type, + static_cast(augmented_inputs_.size()), + augmented_inputs_.data(), + static_cast(augmented_outputs_.size()), + augmented_outputs_.data())); + augmented_outputs_.clear(); + augmented_outputs_.clear(); + return kTfLiteOk; + } + + private: + // TfLiteContext for error handling. Must be named context for macros to + // work. + TfLiteContext* context_; + + // Tracks relationship between indices + OperandMapping* operand_mapping_; + + // The model + ANeuralNetworksModel* nn_model_; + + // Inputs and outputs for the current op. These are augmented in the sense + // that NN API uses operands for all arguments, not just tensors, unlike + // TensorFlow lite. + std::vector augmented_inputs_; + std::vector augmented_outputs_; +}; + +// The kernel that represents the subgraph of TF Lite being run on NN API. +class NNAPIDelegateKernel { + public: + NNAPIDelegateKernel() = default; + + typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*, + NNAPIOpBuilder* builder, + TfLiteNode* node); + + // Return a function that knows how to translate a node into its operands + // when called. You can use this function to see if a node is supported + // (i.e. that MappingFn is not nullptr). + MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + switch (builtin_code) { + case kTfLiteBuiltinAdd: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + break; + case kTfLiteBuiltinAveragePool2d: + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->filter_width); + builder->AddScalarInt32Operand(builtin->filter_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + break; + default: + return nullptr; + } + } + + // Initialize the kernel (a NN model). + TfLiteStatus Init(TfLiteContext* context, + const TfLiteDelegateParams* params) { + for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) { + nodes_.push_back(node_index); + } + + if (!nn_model_) { + ANeuralNetworksModel* model; + CHECK_NN(context, ANeuralNetworksModel_create(&model)); + nn_model_.reset(model); + + TF_LITE_ENSURE_STATUS( + BuildGraph(context, params->input_tensors, params->output_tensors)); + } + + if (!nn_compilation_) { + ANeuralNetworksCompilation* compilation; + CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(), + &compilation)); + CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation)); + nn_compilation_.reset(compilation); + } + return kTfLiteOk; + } + + TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) { + ANeuralNetworksExecution* execution = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(), + &execution)); + + // Set the input tensor buffers. Note: we access tflite tensors using + // absolute indices but NN api indices inputs by relative indices. + int relative_input_index = 0; + for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { + TfLiteTensor* tensor = &context->tensors[absolute_input_index]; + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } + + // Set the output tensor buffers. + int relative_output_index = 0; + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TfLiteTensor* tensor = &context->tensors[output_index]; + CHECK_NN(context, ANeuralNetworksExecution_setOutput( + execution, relative_output_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_output_index++; + } + // Invoke ANN in blocking fashion. + ANeuralNetworksEvent* event = nullptr; + CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event)); + CHECK_NN(context, ANeuralNetworksEvent_wait(event)); + ANeuralNetworksEvent_free(event); + ANeuralNetworksExecution_free(execution); + + return kTfLiteOk; + } + + private: + // ANN API state. + std::unique_ptr nn_model_; + std::unique_ptr + nn_compilation_; + // Node indices that this delegate is responsible for. Indices here + // indexes into the nodes array in the TfLiteContext. + std::vector nodes_; + // Track indices we use + OperandMapping operand_mapping_; + + TfLiteStatus AddOpsAndTensors(TfLiteContext* context) { + // The operand builder allows creating a single op. We create it at this + // reduced power position rather than in the for loop to avoid reallocating + // the vectors. + NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get()); + // Add Tensors + // allocate outside to avoid realloc + for (auto node_index : nodes_) { + // Obtain the op and registration. + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + // Map inputs to NN API tensor indices. + for (auto input_index : TfLiteIntArrayView(node->inputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); + } + // Get op type and operands + int nn_op_type = + Map(context, reg->builtin_code, node)(context, &builder, node); + // Map outputs to NN API tensor indices. + for (auto output_index : TfLiteIntArrayView(node->outputs)) { + TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); + } + + builder.FinalizeAddOperation(nn_op_type); + } + return kTfLiteOk; + } + + TfLiteStatus BuildGraph(TfLiteContext* context, + const TfLiteIntArray* input_tensors, + const TfLiteIntArray* output_tensors) { + // Build the ops and tensors. + TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context)); + // Map input and output tensor indices to ANN + std::vector inputs; + inputs.reserve(input_tensors->size); + std::vector outputs; + outputs.reserve(output_tensors->size); + // Make the TensorFlow lite inputs and outputs to ann_indices. + for (int i : TfLiteIntArrayView(input_tensors)) + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(output_tensors)) + outputs.push_back(operand_mapping_.lite_index_to_ann(i)); + // Tell ANN to declare inputs/outputs + CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs( + nn_model_.get(), inputs.size(), inputs.data(), + outputs.size(), outputs.data())); + // Finalize the model + CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get())); + + return kTfLiteOk; + } +}; + +} // namespace + +// Return a NN API Delegate struct that can check for support of ops. +TfLiteDelegate* NnApiDelegate() { + static TfLiteDelegate delegate = { + .data_ = nullptr, + .Prepare = [](TfLiteContext* context, + TfLiteDelegate* delegate) -> TfLiteStatus { + // Do not check nodes_ if NN API is unavailable. + if (!NNAPIExists()) return kTfLiteOk; + + std::vector supported_nodes(1); + // We don't care about all nodes_, we only care about ones in the + // current plan. + TfLiteIntArray* plan; + TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); + int total_supported_nodes = 0; + // Check for every node if it is supported + // TODO(b/80625235): Fix this to do more careful checking of versioning. + for (int node_index : TfLiteIntArrayView(plan)) { + TfLiteNode* node; + TfLiteRegistration* registration; + TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( + context, node_index, &node, ®istration)); + NNAPIDelegateKernel dummy_kernel; + if (dummy_kernel.Map(context, registration->builtin_code, node)) { + supported_nodes.push_back(node_index); + } + total_supported_nodes += 1; + } + // Put the size at the beginning of the array. + supported_nodes[0] = supported_nodes.size() - 1; + + // NN API Delegate Registration (the pseudo kernel that will invoke NN + // API subgraphs) + static const TfLiteRegistration nnapi_delegate_kernel = { + .init = [](TfLiteContext* context, const char* buffer, + size_t length) -> void* { + const TfLiteDelegateParams* params = + reinterpret_cast(buffer); + NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel; + kernel_state->Init(context, params); + return kernel_state; + }, + + .free = [](TfLiteContext* context, void* buffer) -> void { + delete reinterpret_cast(buffer); + }, + + .prepare = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + // Since the underlying resize happened ahead of delegation + // worked. This does nothing. + return kTfLiteOk; + }, + + .invoke = [](TfLiteContext* context, + TfLiteNode* node) -> TfLiteStatus { + NNAPIDelegateKernel* state = + reinterpret_cast(node->user_data); + return state->Invoke(context, node); + }, + + .builtin_code = kTfLiteBuiltinDelegate, + }; + + // Request TFLite to partition the graph and make kernels + // for each independent subgraph a new nnapi_delegate_kernel. + context->ReplaceSubgraphsWithDelegateKernels( + context, nnapi_delegate_kernel, + reinterpret_cast(supported_nodes.data()), + delegate); + return kTfLiteOk; + }}; + + return &delegate; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h new file mode 100644 index 0000000000000000000000000000000000000000..44cca2fd285370d700525f98ba33c861fb97be1e --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ +#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Return a delegate that can be used to use the NN API. +// e.g. +// NnApiDelegate* delegate = NnApiDelegate(); +// interpreter->ModifyGraphWithDelegate(&delegate); +// NnApiDelegate() returns a singleton, so you should not free this +// pointer or worry about its lifetime. +TfLiteDelegate* NnApiDelegate(); +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_ diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ff2e721423f07889f36746a2889afcc3369f28fc --- /dev/null +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h" +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class FloatAddOpModel : public SingleOpModel { + public: + FloatAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +// Do a test with the NN API using no activation. +TEST(NNAPIDelegate, AddWithNoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +// Do a test with the NN api with relu. +TEST(NNAPIDelegate, AddWithRelu) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 436c3e1d4cad5e6ee355d7e9cf8ee7da1a8385ce..840015a7fad173dbd2ea353786871dd4e89abb98 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -30,9 +30,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. -GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD index 9322e186a280e932a2441ab16ac8579d9ab67ee2..c61445114ecc6dfbe4f2b6ab666b28a8aa746be3 100644 --- a/tensorflow/contrib/lite/examples/label_image/BUILD +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -53,19 +53,18 @@ cc_library( ], ) -# TODO(ahentz): Test disabled as it has a memory leek from read_bmp -# cc_test( -# name = "label_image_test", -# srcs = [ -# "get_top_n.h", -# "get_top_n_impl.h", -# "label_image_test.cc", -# ], -# data = [ -# "testdata/grace_hopper.bmp", -# ], -# deps = [ -# ":bitmap_helpers", -# "//testing/base/public:gunit", -# ], -# ) +cc_test( + name = "label_image_test", + srcs = [ + "get_top_n.h", + "get_top_n_impl.h", + "label_image_test.cc", + ], + data = [ + "testdata/grace_hopper.bmp", + ], + deps = [ + ":bitmap_helpers", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc index 0b38cd38c83927c65d251b9356301b6bef7521f2..2735d1f5ea4e2a104f71a3a6f874d9acb2f48142 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc @@ -28,8 +28,9 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, - int width, int height, int channels, bool top_down) { +std::vector decode_bmp(const uint8_t* input, int row_size, int width, + int height, int channels, bool top_down) { + std::vector output(height * width * channels); for (int i = 0; i < height; i++) { int src_pos; int dst_pos; @@ -66,12 +67,11 @@ uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, } } } - return output; } -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s) { +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s) { int begin, end; std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); @@ -87,14 +87,15 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, if (s->verbose) LOG(INFO) << "len: " << len << "\n"; - const uint8_t* img_bytes = new uint8_t[len]; + std::vector img_bytes(len); file.seekg(0, std::ios::beg); - file.read((char*)img_bytes, len); + file.read(reinterpret_cast(img_bytes.data()), len); const int32_t header_size = - *(reinterpret_cast(img_bytes + 10)); - *width = *(reinterpret_cast(img_bytes + 18)); - *height = *(reinterpret_cast(img_bytes + 22)); - const int32_t bpp = *(reinterpret_cast(img_bytes + 28)); + *(reinterpret_cast(img_bytes.data() + 10)); + *width = *(reinterpret_cast(img_bytes.data() + 18)); + *height = *(reinterpret_cast(img_bytes.data() + 22)); + const int32_t bpp = + *(reinterpret_cast(img_bytes.data() + 28)); *channels = bpp / 8; if (s->verbose) @@ -110,10 +111,9 @@ uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, bool top_down = (*height < 0); // Decode image, allocating tensor once the image size is known - uint8_t* output = new uint8_t[abs(*height) * *width * *channels]; const uint8_t* bmp_pixels = &img_bytes[header_size]; - return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), - *channels, top_down); + return decode_bmp(bmp_pixels, row_size, *width, abs(*height), *channels, + top_down); } } // namespace label_image diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h index 97343dde6b31694e5b2de20b35a7083fb8fe4a0e..5fc75b1f7274c14d49e4a26d6ce4902c037afa6b 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -22,8 +22,8 @@ limitations under the License. namespace tflite { namespace label_image { -uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, - int* channels, Settings* s); +std::vector read_bmp(const std::string& input_bmp_name, int* width, + int* height, int* channels, Settings* s); template void resize(T* out, uint8_t* in, int image_height, int image_width, diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc index 966fcd2a31fd4d4ff2c3e91633550a8effa81ee8..86d7d1cc4a625243791d5e7d5b746526a58efb6d 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -138,8 +138,8 @@ void RunInference(Settings* s) { int image_width = 224; int image_height = 224; int image_channels = 3; - uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, - &image_channels, s); + std::vector in = read_bmp(s->input_bmp_name, &image_width, + &image_height, &image_channels, s); int input = interpreter->inputs()[0]; if (s->verbose) LOG(INFO) << "input: " << input << "\n"; @@ -168,12 +168,12 @@ void RunInference(Settings* s) { switch (interpreter->tensor(input)->type) { case kTfLiteFloat32: s->input_floating = true; - resize(interpreter->typed_tensor(input), in, image_height, - image_width, image_channels, wanted_height, wanted_width, - wanted_channels, s); + resize(interpreter->typed_tensor(input), in.data(), + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); break; case kTfLiteUInt8: - resize(interpreter->typed_tensor(input), in, + resize(interpreter->typed_tensor(input), in.data(), image_height, image_width, image_channels, wanted_height, wanted_width, wanted_channels, s); break; diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc index ce35483f76e8f40ced79e1ee30774c62d0eba94e..de7de21f7741d3d46cb96e793e8bc4bfb21384fe 100644 --- a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc +++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc @@ -27,20 +27,20 @@ namespace label_image { TEST(LabelImageTest, GraceHopper) { std::string lena_file = - "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp"; + "tensorflow/contrib/lite/examples/label_image/testdata/" + "grace_hopper.bmp"; int height, width, channels; Settings s; - uint8_t *data; - - data = read_bmp(lena_file, &width, &height, &channels, &s); + std::vector input = + read_bmp(lena_file, &width, &height, &channels, &s); ASSERT_EQ(height, 606); ASSERT_EQ(width, 517); ASSERT_EQ(channels, 3); - uint8_t *out = new uint8_t[606 * 517 * 3]; - downsize(out, data, 606, 517, 3, 214, 214, 3, &s); - ASSERT_EQ(out[0], 0x15); - ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12); + std::vector output(606 * 517 * 3); + resize(output.data(), input.data(), 606, 517, 3, 214, 214, 3, &s); + ASSERT_EQ(output[0], 0x15); + ASSERT_EQ(output[214 * 214 * 3 - 1], 0x11); } TEST(LabelImageTest, GetTopN) { diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc index 106e3b027055b67092f653c6bcdc4827b56bdbaa..8b0ace96ccaf06ac1cbdc2ea95ac6e92ef886993 100644 --- a/tensorflow/contrib/lite/examples/minimal/minimal.cc +++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc @@ -38,7 +38,7 @@ using namespace tflite; int main(int argc, char *argv[]) { if(argc != 2) { - fprintf(stderr, "Usage: %s \n"); + fprintf(stderr, "minimal \n"); return 1; } const char* filename = argv[1]; diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md new file mode 100644 index 0000000000000000000000000000000000000000..bd2f797e6c5b05f52bec9fc34f1b8011aca70330 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md @@ -0,0 +1,206 @@ +# TensorFlow Lite Ops Versioning + +This document describes TensorFlow Lite's op versioning schema. Op +versioning enables developers to add new functionalities and parameters into +existing ops. In addition, it guarantees the following: + +* Backward compatibility: New TensorFlow Lite implementation should + handle an old model file. +* Forward compatibility: Old TensorFlow Lite implementation should + handle a new model file produced by new version of TOCO, as long as no new + features are used. +* Forward in-compatibility detection: If an old TensorFlow Lite implementation + reads a new model that contains a new version of an op which isn't + supported, it should report the error. + +## Example: Adding Dilation into Convolution + +The remainder of this document explains op versioning in TFLite by showing how +to add dilation parameters to the convolution operation. + +Knowledge of dilation is not required to understand this document. Note that: + +* 2 new integer parameters will be added: `dilation_width_factor` and + `dilation_height_factor`. +* Old convolution kernels that don't support dilation are equivalent to + setting the dilation factors to 1. + +### Change FlatBuffer Schema + +To add new parameters into an op, change the options table in +`lite/schema/schema.fbs`. + +For example, the options table of convolution looks like this: + +``` +table Conv2DOptions { + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; +} +``` + +When adding new parameters: + +* Add comments indicating which parameters are supported by which version. +* When the new implementation gets the default values for newly added + parameters, it should work exactly the same as the old implementation. + +The table will be like this after the new parameters are added: + +``` +table Conv2DOptions { + // Parameters supported by version 1: + padding:Padding; + stride_w:int; + stride_h:int; + fused_activation_function:ActivationFunctionType; + + // Parameters supported by version 2: + dilation_width_factor:int = 1; + dilation_height_factor:int = 1; +} +``` + +### Change C Structures and Kernel Implementation + +In TensorFlow Lite, the kernel implementation is decoupled from +FlatBuffer definition. The kernels read the parameter from C structures defined +in `lite/builtin_op_data.h`. + +The original convolution parameter is as follows: + +``` +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; +``` + +As with the FlatBuffer schema, add comments indicating which parameters are +supported starting from which version. The result is seen below: + +``` +typedef struct { + // Parameters supported by version 1: TfLitePadding padding; int + stride_width; + int stride_height; + TfLiteFusedActivation activation; + + // Parameters supported by version 2: + int dilation_width_factor; + int dilation_height_factor; +} TfLiteConvParams; +``` + +Please also change the kernel implementation to read the newly added parameters +from the C structures. The details are omitted here. + +### Change the FlatBuffer Reading Code + +The logic to read FlatBuffer and produce C structure is in `lite/model.cc`. + +Update the file to handle the new parameters, as shown below: + +``` +case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + params->dilation_width_factor = conv_params->dilation_width_factor(); + params->dilation_height_factor = conv_params->dilation_height_factor(); + } + *builtin_data = reinterpret_cast(params); + break; +} +``` + +It's not required to check the op version here. When the new implementation +reads an old model file where dilation factors are missing, it will use 1 as +the default value, and the new kernel will work consistently with the old +kernel. + +### Change Kernel Registration + +The MutableOpResolver (defined in `lite/op_resolver.h`) provides a few functions +to register op kernels. The minimum and maximum version are 1 by default: + +``` +void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +void AddCustom(const char* name, TfLiteRegistration* registration, + int min_version = 1, int max_version = 1); +``` + +The built-in ops are registered in `lite/kernels/register.cc`. In this example, +we implemented a new op kernel which can handle `Conv2D` version 1 and 2, so we +need to change this line: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); +``` + +to: + +``` +AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(), 1, 2); +``` + +### Change TOCO TFLite exporter + +The last step is to make TOCO populate the minimum version that's required to +execute the op. In this example, it means: + +* Populate version=1 when dilation factors are all 1. +* Populate version=2 otherwise. + +To do this, you need to override `GetVersion` function for the operator class in +`lite/toco/tflite/operator.cc`. + +For ops with only one version, the `GetVersion` function is defined as: + +``` +int GetVersion(const Operator& op) const override { return 1; } +``` + +When supporting multiple versions, check the parameters and determine the +version for the op, as shown in the following example: + +``` +int GetVersion(const Operator& op) const override { + const auto& conv_op = static_cast(op); + if (conv_op.dilation_width_factor != 1 || + conv_op.dilation_height_factor != 1) { + return 2; + } + return 1; +} +``` + +### Delegation Implementation + +TensorFlow Lite provides a delegation API which enables delegating ops to +hardware backends. In Delegate's `Prepare` function, check if the version +is supported for every node in Delegation code. + +``` +const int kMinVersion = 1; +TfLiteNode* node; +TfLiteRegistration; +context->GetNodeAndRegistration(context, node_index, &node, ®istration); + +if (registration->version > kMinVersion) { + // Reject the node if the version isn't supported. +} +``` + +This is required even if the delegation only supports version 1 ops, so the +delegation can detect incompatibility when getting a higher version op. + diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index d8c46e633151cba94ff3d2a3c8b0ab5c230f245e..965273f0f04d33b52903c0551fff3533c31d3bd8 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -95,11 +95,7 @@ Here is a list of TensorFlow operations that are usually removed from the graph: * [tf.divide](https://www.tensorflow.org/api_docs/python/tf/divide) * [tf.fake_quant_with_min_max_args](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_args) * [tf.fake_quant_with_min_max_vars](https://www.tensorflow.org/api_docs/python/tf/fake_quant_with_min_max_vars) -* [tf.greater](https://www.tensorflow.org/api_docs/python/tf/greater) -* [tf.greater_equal](https://www.tensorflow.org/api_docs/python/tf/greater_equal) * [tf.identity](https://www.tensorflow.org/api_docs/python/tf/identity) -* [tf.less](https://www.tensorflow.org/api_docs/python/tf/less) -* [tf.less_equal](https://www.tensorflow.org/api_docs/python/tf/less_equal) * [tf.maximum](https://www.tensorflow.org/api_docs/python/tf/maximum) * [tf.minimum](https://www.tensorflow.org/api_docs/python/tf/minimum) * [tf.multiply](https://www.tensorflow.org/api_docs/python/tf/multiply) @@ -257,6 +253,19 @@ Options { } ``` +**EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is + equal to the corresponding element of the second tensor. +} +``` + **EXP** ``` @@ -420,6 +429,17 @@ Outputs { } ``` +**LOG** + +``` +Inputs { + 0: a tensor +} +Outputs { + 0: a tensor equivalent to log(input) +} +``` + **LOG_SOFTMAX** ``` @@ -503,6 +523,19 @@ Options { } ``` +**NOT_EQUAL** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: a tensor of type bool, true whenever an element of the first tensor is not + equal to the corresponding element of the second tensor. +} +``` + **RELU** ``` @@ -607,6 +640,21 @@ Outputs { } ``` +**SPARSE_TO_DENSE** + +``` +Inputs { + 0: 0D or 1D or 2D tensor + 1: 1D tensor + 2: 0D or 1D tensor + 3: 0D tensor + 4: a boolean value +} +Outputs { + 0: Dense Tensor of shape output_shape. Has the same type as sparse_values. +} +``` + **SPLIT** ``` diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 453c1ada1cf6263be14a3b170f209e3a30580cc3..4c78466480bdcfb42b1f582ecc7c185201a81b05 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -211,7 +211,7 @@ TEST(BasicInterpreter, CheckArenaAllocation) { TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; std::vector sizes{2048, 4096, 1023, 2047, 1021, - 2047, 1023, 2046, 1021, 2048}; + 2047, 1023, 2046, 0, 2048}; for (int i = 0; i < sizes.size(); ++i) { interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, quant); @@ -228,6 +228,7 @@ TEST(BasicInterpreter, CheckArenaAllocation) { ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); + ASSERT_EQ(interpreter.tensor(8)->data.raw, nullptr); ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); @@ -314,6 +315,18 @@ TEST(BasicInterpreter, ResizingTensors) { EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(t, {}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 1 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 0}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 0); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // TODO(ahentz): We shouldn't have to force reallocation, but // ResizeInputTensor doesn't realloc dynamic tensors. Also note that // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java index 644ce4cb3e0beaed2b9ae542cdacbb912ab0f010..fd1f0ffa68eeca7b5866b146ecaa1f9216ef377d 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -17,6 +17,7 @@ package org.tensorflow.lite; import java.io.File; import java.nio.ByteBuffer; +import java.nio.MappedByteBuffer; import java.util.HashMap; import java.util.Map; import org.checkerframework.checker.nullness.qual.NonNull; @@ -103,6 +104,27 @@ public final class Interpreter implements AutoCloseable { wrapper = new NativeInterpreterWrapper(byteBuffer, numThreads); } + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + } + + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file and + * specifies the number of threads used for inference. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NonNull MappedByteBuffer mappedByteBuffer, int numThreads) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer, numThreads); + } + /** * Runs model inference if the model takes only one input, and provides only one output. * @@ -231,5 +253,14 @@ public final class Interpreter implements AutoCloseable { wrapper = null; } + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } + } + NativeInterpreterWrapper wrapper; } diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc index 005dca0253d2c30d56a15adf6e2b371d43f50945..9e9387da86ebde7d711a7ce967461e370c95bc3e 100644 --- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -43,31 +43,27 @@ size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, } switch (type) { case kTfLiteFloat32: { - jfloatArray a = static_cast(array); - jfloat* values = env->GetFloatArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + jfloatArray float_array = static_cast(array); + jfloat* float_dst = static_cast(dst); + env->GetFloatArrayRegion(float_array, 0, num_elements, float_dst); return to_copy; } case kTfLiteInt32: { - jintArray a = static_cast(array); - jint* values = env->GetIntArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseIntArrayElements(a, values, JNI_ABORT); + jintArray int_array = static_cast(array); + jint* int_dst = static_cast(dst); + env->GetIntArrayRegion(int_array, 0, num_elements, int_dst); return to_copy; } case kTfLiteInt64: { - jlongArray a = static_cast(array); - jlong* values = env->GetLongArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseLongArrayElements(a, values, JNI_ABORT); + jlongArray long_array = static_cast(array); + jlong* long_dst = static_cast(dst); + env->GetLongArrayRegion(long_array, 0, num_elements, long_dst); return to_copy; } case kTfLiteUInt8: { - jbyteArray a = static_cast(array); - jbyte* values = env->GetByteArrayElements(a, nullptr); - memcpy(dst, values, to_copy); - env->ReleaseByteArrayElements(a, values, JNI_ABORT); + jbyteArray byte_array = static_cast(array); + jbyte* byte_dst = static_cast(dst); + env->GetByteArrayRegion(byte_array, 0, num_elements, byte_dst); return to_copy; } default: { diff --git a/tensorflow/contrib/lite/kernels/Android.bp b/tensorflow/contrib/lite/kernels/Android.bp index 59262d398ef173008d18ed912fc64d7ea5cb97d9..0e89cc33f6e7438f38d121cac2639965b35f9e51 100644 --- a/tensorflow/contrib/lite/kernels/Android.bp +++ b/tensorflow/contrib/lite/kernels/Android.bp @@ -49,10 +49,11 @@ cc_library_static { "depthwise_conv.cc", "dequantize.cc", "div.cc", - "elementwise.cc", + "elementwise.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", "exp.cc", + "expand_dims.cc", "floor.cc", "fully_connected.cc", "gather.cc", @@ -76,11 +77,13 @@ cc_library_static { "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "sparse_to_dense.cc", "split.cc", "squeeze.cc", "strided_slice.cc", "sub.cc", "svdf.cc", + "tile.cc", "topk_v2.cc", "transpose.cc", "transpose_conv.cc", diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index b7291dd379a6c09a70a78de7bc6c2f217b293b26..cf5d0b4ce9cb3c516c185f31fea12db70a2c3bdb 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -147,6 +147,7 @@ cc_library( "embedding_lookup.cc", "embedding_lookup_sparse.cc", "exp.cc", + "expand_dims.cc", "floor.cc", "fully_connected.cc", "gather.cc", @@ -170,11 +171,13 @@ cc_library( "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "sparse_to_dense.cc", "split.cc", "squeeze.cc", "strided_slice.cc", "sub.cc", "svdf.cc", + "tile.cc", "topk_v2.cc", "transpose.cc", "transpose_conv.cc", @@ -857,6 +860,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "tile_test", + size = "small", + srcs = ["tile_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "comparisons_test", size = "small", @@ -934,6 +951,34 @@ tf_cc_test( ], ) +tf_cc_test( + name = "expand_dims_test", + size = "small", + srcs = ["expand_dims_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "sparse_to_dense_test", + size = "small", + srcs = ["sparse_to_dense_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 7ca1e35489cba3b5d2567bc04e532fedf8a527a7..443ce8924a43669fb264e19561c733d7e3436cb0 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -126,16 +126,19 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, int32 input1_multiplier; int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, + &input1_multiplier, &input1_shift); + input1_shift *= -1; int32 input2_multiplier; int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, + &input2_multiplier, &input2_shift); + input2_shift *= -1; int32 output_multiplier; int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, + &output_multiplier, &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index 7dc0c5656dca02a86339c558f4fe2babb4961695..c09b15b3d263d6cd639234590c99a50a9a48f4a7 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -36,7 +36,7 @@ constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); return scratch_tensor_index; } @@ -91,7 +91,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = kTfLiteUInt8; @@ -114,6 +114,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } } return kTfLiteOk; @@ -145,14 +155,14 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input, return kTfLiteOk; } -TfLiteStatus EvalQuantized(const TfLiteTensor* input, - const TfLiteTensor* input_weights, - const TfLiteTensor* recurrent_weights, - const TfLiteTensor* bias, - const TfLiteRNNParams* params, - TfLiteTensor* input_scratch, - TfLiteTensor* hidden_state_scratch, - TfLiteTensor* hidden_state, TfLiteTensor* output) { +TfLiteStatus EvalHybrid(const TfLiteTensor* input, + const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, + const TfLiteTensor* bias, const TfLiteRNNParams* params, + TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, + TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; @@ -176,12 +186,14 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, reinterpret_cast(input_scratch->data.uint8); int8_t* quantized_hidden_state_ptr = reinterpret_cast(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; kernel_utils::RnnBatchStep( input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, params->activation, quantized_input_ptr, - quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch); + quantized_hidden_state_ptr, scaling_factors_ptr, hidden_state_ptr_batch, + output_ptr_batch); return kTfLiteOk; } @@ -205,9 +217,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): implement eval with quantized inputs as well. TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); - return EvalQuantized(input, input_weights, recurrent_weights, bias, - params, input_quantized, hidden_state_quantized, - hidden_state, output); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); } default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index 3b81062cd42f04582b33ea919ef2742d3d869c22..f678f48fa5bbbcece6c5b87030d951783378d78f 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -23,6 +23,7 @@ namespace tflite { namespace ops { namespace builtin { namespace comparisons { +namespace { constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; @@ -67,6 +68,57 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { GetTensorData(input2), GetTensorDims(input2), \ GetTensorData(output), GetTensorDims(output)); +TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, Equal, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, Equal, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, Equal, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +// TODO(renjieliu): Refactor the logic to avoid duplications. +TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + bool requires_broadcast = !HaveSameShapes(input1, input2); + // TODO(renjieliu): Support quantized data. + switch (input1->type) { + case kTfLiteFloat32: + TF_LITE_COMPARISON(float, NotEqual, requires_broadcast); + break; + case kTfLiteInt32: + TF_LITE_COMPARISON(int32_t, NotEqual, requires_broadcast); + break; + case kTfLiteInt64: + TF_LITE_COMPARISON(int64_t, NotEqual, requires_broadcast); + break; + default: + context->ReportError(context, + "Does not support type %d, requires float|int", + input1->type); + return kTfLiteError; + } + return kTfLiteOk; +} + TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); @@ -167,8 +219,22 @@ TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +} // namespace } // namespace comparisons +TfLiteRegistration* Register_EQUAL() { + static TfLiteRegistration r = { + nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::EqualEval}; + return &r; +} + +TfLiteRegistration* Register_NOT_EQUAL() { + static TfLiteRegistration r = {nullptr, nullptr, + comparisons::ComparisonPrepare, + comparisons::NotEqualEval}; + return &r; +} + TfLiteRegistration* Register_GREATER() { static TfLiteRegistration r = {nullptr, nullptr, comparisons::ComparisonPrepare, diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index 835d238d36d1757a27119ae24b3c07232e9d3dc0..bb02e1c812fdc40bf515f1f978e9e39b5a16a4ea 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -21,18 +21,17 @@ limitations under the License. namespace tflite { namespace { -using ::testing::ElementsAreArray; +using ::testing::ElementsAre; -class GreaterOpModel : public SingleOpModel { +class ComparisonOpModel : public SingleOpModel { public: - GreaterOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { + ComparisonOpModel(std::initializer_list input1_shape, + std::initializer_list input2_shape, + TensorType input_type, BuiltinOperator op) { input1_ = AddInput(input_type); input2_ = AddInput(input_type); output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER, BuiltinOptions_GreaterOptions, - CreateGreaterOptions(builder_).Union()); + ConfigureBuiltinOp(op); BuildInterpreter({input1_shape, input2_shape}); } @@ -46,245 +45,313 @@ class GreaterOpModel : public SingleOpModel { int input1_; int input2_; int output_; + + void ConfigureBuiltinOp(BuiltinOperator op) { + switch (op) { + case BuiltinOperator_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_EqualOptions, + CreateEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_NOT_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_NotEqualOptions, + CreateNotEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER: { + SetBuiltinOp(op, BuiltinOptions_GreaterOptions, + CreateGreaterOptions(builder_).Union()); + break; + } + case BuiltinOperator_GREATER_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_GreaterEqualOptions, + CreateGreaterEqualOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS: { + SetBuiltinOp(op, BuiltinOptions_LessOptions, + CreateLessOptions(builder_).Union()); + break; + } + case BuiltinOperator_LESS_EQUAL: { + SetBuiltinOp(op, BuiltinOptions_LessEqualOptions, + CreateLessEqualOptions(builder_).Union()); + break; + } + default: { FAIL() << "We shouldn't get here."; } + } + } }; -TEST(ComparisonsTest, GreaterFloat) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); +TEST(ComparisonsTest, EqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterInt) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcast) { - GreaterOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } -TEST(ComparisonsTest, GreaterBroadcastTwoD) { - GreaterOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); +TEST(ComparisonsTest, EqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, false, false, + false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class GreaterEqualOpModel : public SingleOpModel { - public: - GreaterEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_GREATER_EQUAL, - BuiltinOptions_GreaterEqualOptions, - CreateGreaterEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } +TEST(ComparisonsTest, NotEqualFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); - int input1() { return input1_; } - int input2() { return input2_; } + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } +TEST(ComparisonsTest, NotEqualInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); - private: - int input1_; - int input2_; - int output_; -}; + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, NotEqualBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_NOT_EQUAL); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, true, true, true, true, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} + +TEST(ComparisonsTest, GreaterFloat) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); + model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterInt) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {1, 2, 7, 5}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcast) { + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); + model.PopulateTensor(model.input2(), {7}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(ComparisonsTest, GreaterBroadcastTwoD) { + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER); + model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); + model.PopulateTensor(model.input2(), {7, 1, 2, 4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); +} TEST(ComparisonsTest, GreaterEqualFloat) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualInt) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcast) { - GreaterEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, GreaterEqualBroadcastTwoD) { - GreaterEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_GREATER_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, true, true, false, - false, true, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(false, true, true, false, false, true, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessOpModel : public SingleOpModel { - public: - LessOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions, - CreateLessOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; TEST(ComparisonsTest, LessFloat) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(false, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessInt) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 6, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcast) { - LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessBroadcastTwoD) { - LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, false, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, false, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } -class LessEqualOpModel : public SingleOpModel { - public: - LessEqualOpModel(std::initializer_list input1_shape, - std::initializer_list input2_shape, - TensorType input_type) { - input1_ = AddInput(input_type); - input2_ = AddInput(input_type); - output_ = AddOutput(TensorType_BOOL); - SetBuiltinOp(BuiltinOperator_LESS_EQUAL, BuiltinOptions_LessEqualOptions, - CreateLessEqualOptions(builder_).Union()); - BuildInterpreter({input1_shape, input2_shape}); - } - - int input1() { return input1_; } - int input2() { return input2_; } - - std::vector GetOutput() { return ExtractVector(output_); } - std::vector GetOutputShape() { return GetTensorShape(output_); } - - private: - int input1_; - int input2_; - int output_; -}; - TEST(ComparisonsTest, LessEqualFloat) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {0.1, 0.9, 0.7, 0.3}); model.PopulateTensor(model.input2(), {0.1, 0.2, 0.6, 0.5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, false, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualInt) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {1, 2, 7, 5}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcast) { - LessEqualOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32); + ComparisonOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3}); model.PopulateTensor(model.input2(), {7}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, true, true})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4})); + EXPECT_THAT(model.GetOutput(), ElementsAre(true, false, true, true)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 1, 4)); } TEST(ComparisonsTest, LessEqualBroadcastTwoD) { - LessEqualOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32); + ComparisonOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32, + BuiltinOperator_LESS_EQUAL); model.PopulateTensor(model.input1(), {-1, 9, 7, 3, 2, 4, 2, 8}); model.PopulateTensor(model.input2(), {7, 1, 2, 4}); model.Invoke(); - EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true, - true, false, true, false})); - EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4})); + EXPECT_THAT(model.GetOutput(), + ElementsAre(true, false, false, true, true, false, true, false)); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 1, 2, 4)); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 0b35a220e783572985206e918f3fcd8361f16790..14b399ef96eab1d5066a22a7eb95ab061e8ba2bc 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -134,7 +134,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). data->need_im2col = (params->stride_width != 1 || params->stride_height != 1 || - filter_width != 1 || filter_height != 1); + params->dilation_width_factor != 1 || + params->dilation_height_factor != 1 || filter_width != 1 || + filter_height != 1); // If we're using the optimized multithreaded EigenTensor implementation of // convolution, it expects the filter weights to be transposed compared to // the normal TF Lite buffer format. Typical TF Lite weights are @@ -254,8 +256,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + TF_LITE_ENSURE(context, real_multiplier < 1.0); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc index abb2549f85164e40875456fb732716e53b263127..a308de055f49eddba99d02e264fad11409a799f4 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -151,8 +151,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + int exponent; + QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent); + data->output_shift = -exponent; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc index 1439c8bce14ad127ed68dc54991aed8b8bb39383..c00cafb9fbfaf53d4dc301ccd3f21a6c6fd892e6 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -47,12 +47,6 @@ class BaseDepthwiseConvolutionOpModel : public SingleOpModel { } output_ = AddOutput(output); - if (input.type != TensorType_FLOAT32) { - // The following is required by quantized inference. It is the unittest's - // responsibility to make sure the output scale falls into the correct - // range. - CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); - } int input_depth = GetShape(input_)[3]; int output_depth = GetShape(filter_)[3]; @@ -176,6 +170,43 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { })); } +TEST(QuantizedDepthwiseConvolutionOpTest, + SimpleTestQuantizedFilterMultiplierGreaterThan1) { + QuantizedDepthwiseConvolutionOpModel quant_op( + {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 4}, -128.5, 128}, + {TensorType_UINT8, {}, -127, 128}); + DepthwiseConvolutionOpModel float_op({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + std::initializer_list input = { + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }; + std::initializer_list filter = { + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }; + std::initializer_list bias = {1, 2, 3, 4}; + + quant_op.SetInput(input); + quant_op.SetFilter(filter); + quant_op.SetBias(bias); + quant_op.Invoke(); + + float_op.SetInput(input); + float_op.SetFilter(filter); + float_op.SetBias(bias); + float_op.Invoke(); + + EXPECT_THAT(quant_op.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear(float_op.GetOutput(), 1))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc index 0bd504695074011efd946f4c4d1f8d4854e82730..98c21ce9d390aaa1f3cb5fdb8f31cbffb1b81d6a 100644 --- a/tensorflow/contrib/lite/kernels/elementwise.cc +++ b/tensorflow/contrib/lite/kernels/elementwise.cc @@ -23,7 +23,7 @@ namespace ops { namespace builtin { namespace elementwise { -TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); const TfLiteTensor* input = GetInput(context, node, 0); @@ -35,7 +35,8 @@ TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArrayCopy(input->dims)); } -TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { +inline TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, + float float_func(float)) { const TfLiteTensor* input = GetInput(context, node, 0); TfLiteTensor* output = GetOutput(context, node, 0); switch (input->type) { @@ -44,7 +45,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { const float* in = GetTensorData(input); const float* in_end = in + elements; float* out = output->data.f; - for (; in < in_end; in++, out++) *out = std::sin(*in); + for (; in < in_end; in++, out++) *out = float_func(*in); return kTfLiteOk; } default: { @@ -55,14 +56,28 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::sin); +} + +TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { + return Eval(context, node, std::log); +} + } // namespace elementwise TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, elementwise::SinPrepare, + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, elementwise::SinEval}; return &r; } +TfLiteRegistration* Register_LOG() { + static TfLiteRegistration r = {nullptr, nullptr, elementwise::GenericPrepare, + elementwise::LogEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/elementwise_test.cc b/tensorflow/contrib/lite/kernels/elementwise_test.cc index 412ffb04b90fbc24d232d25d2a86ce639752c3e8..10e88d5a31868eeb5f65c7ade1f1c73827dea24a 100644 --- a/tensorflow/contrib/lite/kernels/elementwise_test.cc +++ b/tensorflow/contrib/lite/kernels/elementwise_test.cc @@ -24,12 +24,13 @@ namespace { using ::testing::ElementsAreArray; -class SinOpModel : public SingleOpModel { +class ElementWiseOpModel : public SingleOpModel { public: - SinOpModel(std::initializer_list input_shape) { + ElementWiseOpModel(BuiltinOperator op, + std::initializer_list input_shape) { input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp(BuiltinOperator_SIN, BuiltinOptions_NONE, 0); + SetBuiltinOp(op, BuiltinOptions_NONE, 0); BuildInterpreter({input_shape}); } @@ -42,7 +43,7 @@ class SinOpModel : public SingleOpModel { }; TEST(ElementWise, Sin) { - SinOpModel m({1, 1, 4, 1}); + ElementWiseOpModel m(BuiltinOperator_SIN, {1, 1, 4, 1}); m.PopulateTensor(m.input(), {0, 3.1415926, -3.1415926, 1}); m.Invoke(); EXPECT_THAT(m.ExtractVector(m.output()), @@ -50,6 +51,15 @@ TEST(ElementWise, Sin) { EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); } +TEST(ElementWise, Log) { + ElementWiseOpModel m(BuiltinOperator_LOG, {1, 1, 4, 1}); + m.PopulateTensor(m.input(), {1, 3.1415926, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.ExtractVector(m.output()), + ElementsAreArray(ArrayFloatNear({0, 1.14473, 0, 0}))); + EXPECT_THAT(m.GetTensorShape(m.output()), ElementsAreArray({1, 1, 4, 1})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc index 7539c0b30ded921df957217bebdc7b20ea4b40b4..9410bead5e7a68363d034c22fb2c0eff9f060ef1 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -24,7 +24,8 @@ limitations under the License. // Output: // Output.dim[0] == Tensor[0].dim[0], num of lookups // Output.dim[1] == Tensor[1].dim[1], num of items per row -// Each item in output is a raw bytes copy of corresponding item in input. +// Each item in output is a raw bytes copy of the corresponding item in input, +// or a dequantized value in the case of a uint8 input. // When indices are out of bound, the ops will not succeed. // @@ -69,11 +70,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, outputSize); } -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - TfLiteTensor* output = GetOutput(context, node, 0); - const TfLiteTensor* lookup = GetInput(context, node, 0); - const TfLiteTensor* value = GetInput(context, node, 1); - +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { const int row_size = SizeOfDimension(value, 0); const int row_bytes = value->bytes / row_size; @@ -91,6 +90,52 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, + const TfLiteTensor* lookup, const TfLiteTensor* value, + TfLiteTensor* output) { + const int row_size = SizeOfDimension(value, 0); + const double scaling_factor = 1.0 / value->params.scale; + + // col_size after we flatten tensor into 2D. + int col_size = 1; + for (int i = 1; i < NumDimensions(value); i++) { + col_size *= SizeOfDimension(value, i); + } + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + // Dequantize embedding values. + // TODO(alanchiao): refactor scalar multiply into separate function + // for ease of adding a neon equivalent if ever necessary. + for (int j = 0; j < col_size; j++) { + output->data.f[j + i * col_size] = + value->data.uint8[j + idx * col_size] * scaling_factor; + } + } + } + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* lookup = GetInput(context, node, 0); + const TfLiteTensor* value = GetInput(context, node, 1); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (value->type) { + case kTfLiteFloat32: + return EvalFloat(context, node, lookup, value, output); + case kTfLiteUInt8: + return EvalHybrid(context, node, lookup, value, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +} + } // namespace embedding_lookup TfLiteRegistration* Register_EMBEDDING_LOOKUP() { diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 9b501878f196216a61568bfa36e6615f4dd07478..04657fd86323ef1c58d069c06097c7665f55cc87 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -7,13 +7,14 @@ You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License +for the specific language governing permissions and limitations under the +License. ==============================================================================*/ // Unit test for TFLite Lookup op. +#include #include #include @@ -29,12 +30,13 @@ namespace { using ::testing::ElementsAreArray; -class EmbeddingLookupOpModel : public SingleOpModel { +class BaseEmbeddingLookupOpModel : public SingleOpModel { public: - EmbeddingLookupOpModel(std::initializer_list index_shape, - std::initializer_list weight_shape) { + BaseEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape, + TensorType weight_type = TensorType_FLOAT32) { input_ = AddInput(TensorType_INT32); - weight_ = AddInput(TensorType_FLOAT32); + weight_ = AddInput(weight_type); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); BuildInterpreter({index_shape, weight_shape}); @@ -44,6 +46,18 @@ class EmbeddingLookupOpModel : public SingleOpModel { PopulateTensor(input_, data); } + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weight_; + int output_; +}; + +class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + using BaseEmbeddingLookupOpModel::BaseEmbeddingLookupOpModel; + void Set3DWeightMatrix(const std::function& function) { TfLiteTensor* tensor = interpreter_->tensor(weight_); int rows = tensor->dims->data[0]; @@ -57,20 +71,25 @@ class EmbeddingLookupOpModel : public SingleOpModel { } } } +}; - std::vector GetOutput() { return ExtractVector(output_); } +class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { + public: + HybridEmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) + : BaseEmbeddingLookupOpModel(index_shape, weight_shape, + TensorType_UINT8) {} - private: - int input_; - int weight_; - int output_; + void SetWeight(std::initializer_list data) { + SymmetricQuantizeAndPopulate(weight_, data); + } }; // TODO(ahentz): write more tests that exercise the details of the op, such as // lookup errors and variable input shapes. TEST(EmbeddingLookupOpTest, SimpleTest) { EmbeddingLookupOpModel m({3}, {3, 2, 4}); - m.PopulateTensor(0, {1, 0, 2}); + m.SetInput({1, 0, 2}); m.Set3DWeightMatrix( [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); @@ -84,6 +103,69 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { }))); } +TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 8}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple3DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + +TEST(HybridEmbeddingLookupHybridOpTest, Simple4DTest) { + HybridEmbeddingLookupOpModel m({3}, {3, 2, 2, 2}); + m.SetInput({1, 0, 2}); + m.SetWeight({ + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + { + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }, + 7.41e-03))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc new file mode 100644 index 0000000000000000000000000000000000000000..ed33012864354cd93eac2344f75d7eca302c8952 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims.cc @@ -0,0 +1,113 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace expand_dims { +constexpr int kInput = 0; +constexpr int kAxis = 1; +constexpr int kOutput = 0; + +namespace { +TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input, + int axis, TfLiteTensor* output) { + const TfLiteIntArray& input_dims = *input.dims; + if (axis < 0) { + axis = input_dims.size + 1 + axis; + } + TF_LITE_ENSURE(context, axis <= input_dims.size); + + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1); + for (int i = 0; i < output_dims->size; ++i) { + if (i < axis) { + output_dims->data[i] = input_dims.data[i]; + } else if (i == axis) { + output_dims->data[i] = 1; + } else { + output_dims->data[i] = input_dims.data[i - 1]; + } + } + + return context->ResizeTensor(context, output, output_dims); +} + +TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context, + const TfLiteTensor& axis, int* axis_value) { + TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1); + switch (axis.type) { + case kTfLiteInt32: + *axis_value = *GetTensorData(&axis); + return kTfLiteOk; + case kTfLiteInt64: + *axis_value = *GetTensorData(&axis); + return kTfLiteOk; + default: + return kTfLiteError; + } +} + +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + const TfLiteTensor* input = GetInput(context, node, kInput); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + TfLiteTensor* output = GetOutput(context, node, 0); + output->type = input->type; + if (IsConstantTensor(axis)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + return ExpandTensorDim(context, *input, axis_value, output); + } + SetTensorToDynamic(output); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + // Just copy input to output. + const TfLiteTensor* input = GetInput(context, node, kInput); + TfLiteTensor* output = GetOutput(context, node, 0); + const TfLiteTensor* axis = GetInput(context, node, kAxis); + if (IsDynamicTensor(output)) { + int axis_value; + TF_LITE_ENSURE_OK(context, + GetAxisValueFromTensor(context, *axis, &axis_value)); + TF_LITE_ENSURE_OK(context, + ExpandTensorDim(context, *input, axis_value, output)); + } + memcpy(output->data.raw, input->data.raw, input->bytes); + return kTfLiteOk; +} + +} // namespace expand_dims +TfLiteRegistration* Register_EXPAND_DIMS() { + static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare, + expand_dims::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b755e8ce293442813b26ec3177162a3c95af2f89 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc @@ -0,0 +1,83 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ExpandDimsOpModel : public SingleOpModel { + public: + ExpandDimsOpModel(std::initializer_list input_shape, + TensorType input_type) { + input_ = AddInput(input_type); + axis_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_EXPAND_DIMS, BuiltinOptions_ExpandDimsOptions, + 0); + BuildInterpreter({input_shape, {1}}); + } + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + std::vector GetValuesFloat() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int axis_; + int output_; +}; + +TEST(ExpandDimsOpTest, DifferentAxis) { + ExpandDimsOpModel m({2, 2}, TensorType_FLOAT32); + const auto values = {-1.f, 1.f, -2.f, 2.f}; + m.SetInputFloat(values); + m.SetAxis(0); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 2})); + + m.SetAxis(1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2})); + + m.SetAxis(2); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); + + m.SetAxis(-1); + m.Invoke(); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 1})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index 3374923e6e0f2f4520d2d67e698ae5b4fd8e5443..f6fc0f5b6ad12d58c541efc6eae566ab4b8327f4 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -101,16 +101,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { input_size *= input->dims->data[i]; } + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); const int batch_size = input_size / filter->dims->data[1]; const int num_units = filter->dims->data[0]; - TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + TF_LITE_ENSURE_EQ(context, input_size, batch_size * filter->dims->data[1]); if (bias) { TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0)); } - TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); - // Note that quantized inference requires that all tensors have their // parameters set. This is usually done during quantized training. TfLiteType data_type = input->type; @@ -118,8 +117,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { double real_multiplier = 0.0; TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( context, input, filter, bias, output, &real_multiplier)); - QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, - &data->output_shift); + TF_LITE_ENSURE(context, real_multiplier < 1.0); + QuantizeMultiplierSmallerThanOneExp( + real_multiplier, &data->output_multiplier, &data->output_shift); + data->output_shift *= -1; CalculateActivationRangeUint8(params->activation, output, &data->output_activation_min, &data->output_activation_max); @@ -218,11 +219,8 @@ TfLiteStatus EvalPieQuantized(TfLiteContext* context, TfLiteNode* node, tensor_utils::ZeroVector(output->data.f, batch_size * num_units); } - // TODO(mirkov): change std::minmax_element with a vectorized call. - auto minmax_element = - std::minmax_element(input->data.f, input->data.f + total_input_size); // Save matrix multiplication computation for all zero input. - if (*minmax_element.first == 0.0 && *minmax_element.second == 0.0) { + if (tensor_utils::IsZeroVector(input->data.f, total_input_size)) { tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units, params->activation, output->data.f); diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index aabbb0685c5f9160cd863952cb47526215dc31a6..75298b995d6184985efc76c60c2f5541e9cbea40 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -420,6 +420,15 @@ cc_library( }), ) +cc_library( + name = "test_util", + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":types", + ], +) + cc_test( name = "tensor_utils_test", srcs = ["tensor_utils_test.cc"], @@ -440,6 +449,84 @@ cc_test( ], ) +cc_test( + name = "depthwiseconv_float_test", + srcs = ["depthwiseconv_float_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "depthwiseconv_quantized_test", + srcs = ["depthwiseconv_quantized_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "resize_bilinear_test", + srcs = ["resize_bilinear_test.cc"], + tags = ["tflite_not_portable"], + deps = [ + ":optimized_base", + ":reference_base", + ":test_util", + ":types", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "softmax_quantized_test", + timeout = "long", + srcs = [ + "softmax_quantized_test.cc", + ], + deps = [ + ":optimized_base", + ":quantization_util", + ":reference_base", + ":test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "logsoftmax_quantized_test", + timeout = "long", + srcs = [ + "logsoftmax_quantized_test.cc", + ], + tags = ["tflite_not_portable"], + deps = [ + ":optimized_base", + ":quantization_util", + ":reference_base", + ":test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "log_quantized_test", + srcs = ["log_quantized_test.cc"], + deps = [ + ":optimized_base", + ":reference_base", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "cpu_check", hdrs = [ diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 4d3803a07fbdaf9d0d27664037c219715884ee76..3543b6f79e1633b4ffa0872b2999161a519c8cad 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -89,12 +89,12 @@ float ActivationFunction(float x) { output_activation_max); } -inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( - int32 x, int32 quantized_multiplier, int right_shift) { +inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp( + int32 x, int32 quantized_multiplier, int left_shift) { using gemmlowp::RoundingDivideByPOT; using gemmlowp::SaturatingRoundingDoublingHighMul; return RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), -left_shift); } inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..844ee6a53dd65b81f21ae1ef5b6d04192744a304 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_float_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" + +namespace tflite { +namespace { + +// Runs the DepthwiseConv and compares against the reference implementation. +template +void TestOneDepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, const Dims<4>& output_dims) { + const int output_buffer_size = RequiredBufferSizeForDims(output_dims); + std::vector output_data(output_buffer_size); + std::vector reference_output_data(output_buffer_size); + reference_ops::DepthwiseConv(input_data, input_dims, filter_data, + filter_dims, bias_data, bias_dims, stride, + pad_width, pad_height, depth_multiplier, + reference_output_data.data(), output_dims); + optimized_ops::DepthwiseConv(input_data, input_dims, filter_data, + filter_dims, bias_data, bias_dims, stride, + pad_width, pad_height, depth_multiplier, + output_data.data(), output_dims); + double sum_abs_diff = 0; + float max_abs_val = 0; + for (int i = 0; i < output_buffer_size; i++) { + sum_abs_diff += std::abs(output_data[i] - reference_output_data[i]); + max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i])); + } + if (sum_abs_diff != 0.f) { + const float mean_diff = + static_cast(sum_abs_diff / output_buffer_size); + const float relative_error = std::abs(mean_diff) / max_abs_val; + ASSERT_LT(relative_error, 1e-5f); + } +} + +void TestOneDepthwiseConv(FusedActivationFunctionType Ac, + const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, const Dims<4>& output_dims) { +#define TOCO_HANDLE_CASE(AC_TYPE) \ + if (AC_TYPE == Ac) { \ + TestOneDepthwiseConv(input_data, input_dims, filter_data, \ + filter_dims, bias_data, bias_dims, stride, \ + pad_width, pad_height, depth_multiplier, \ + output_dims); \ + return; \ + } + TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6) +#undef TOCO_HANDLE_CASE +} + +// This function picks some random DepthwiseConv params, which may or may not +// be legal. If they're not legal, it returns false. If they're legal, +// it runs the DepthwiseConv test and returns true. This allows the caller +// to loop until a test has been run. +bool TryTestOneDepthwiseConv() { + // We have to pick a lot of positive values, where we are particularly + // interested in small values because they are most likely to be special + // cases in optimized implementations, and secondarily because they allow + // tests to run fast, which means we can run more tests and get more + // coverage. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const int output_depth = input_depth * depth_multiplier; + // The optimized DepthwiseConv implementation currently uses a fixed-size + // accumulator buffer on the stack, with that size. This currently means + // that it does not support larger output depths. It CHECK's for it, + // so it's safe in the sense that if a larger output depth was encountered, + // it would explicitly fail. We just need to adjust our testing to that + // constraint. + const int kMaxSupportedOutputDepth = 1024; + if (output_depth > kMaxSupportedOutputDepth) { + return false; + } + const auto ac = RandomElement(std::vector( + {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu, + FusedActivationFunctionType::kRelu6, + FusedActivationFunctionType::kRelu1})); + Dims<4> input_dims_inference = + MakeDimsForInference(input_depth, input_width, input_height, batch); + Dims<4> output_dims_inference; + int pad_width, pad_height; + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width, + filter_height, stride, padding_type, + &output_dims_inference, &pad_width, &pad_height)) { + return false; + } + Dims<4> filter_dims_inference = + MakeDimsForInference(output_depth, filter_width, filter_height, 1); + Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1); + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int filter_buffer_size = + RequiredBufferSizeForDims(filter_dims_inference); + std::vector input_data(input_buffer_size); + std::vector filter_data(filter_buffer_size); + std::vector bias_data(output_depth); + const float input_amplitude = 1.f; + const float filter_amplitude = 1.f; + const float bias_amplitude = + filter_width * filter_height * input_amplitude * filter_amplitude; + FillRandom(&input_data, -input_amplitude, input_amplitude); + FillRandom(&filter_data, -filter_amplitude, filter_amplitude); + FillRandom(&bias_data, -bias_amplitude, bias_amplitude); + TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference, + filter_data.data(), filter_dims_inference, + bias_data.data(), bias_dims_inference, stride, pad_width, + pad_height, depth_multiplier, output_dims_inference); + return true; +} + +void TestOneDepthwiseConv() { + while (!TryTestOneDepthwiseConv()) { + } +} + +TEST(TestDepthwiseConv, TestDepthwiseConv) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv(); + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c0fc8433e18fb7f7f89c17380210d94b39ffc94 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" + +namespace tflite { +namespace { + +// Runs the DepthwiseConv and compares against the reference implementation. +template +int TestOneDepthwiseConvWithGivenOutputShift( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + int output_shift, std::int32_t output_activation_min, + std::int32_t output_activation_max, const Dims<4>& output_dims) { + const int output_buffer_size = RequiredBufferSizeForDims(output_dims); + std::vector output_data(output_buffer_size); + std::vector reference_output_data(output_buffer_size); + reference_ops::DepthwiseConv( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, + reference_output_data.data(), output_dims); + optimized_ops::DepthwiseConv( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data.data(), + output_dims); + int saturated_min = 0; + int saturated_max = 0; + std::vector diff(output_buffer_size); + std::int64_t sum_diff = 0; + std::int64_t sum_abs_diff = 0; + for (int i = 0; i < output_buffer_size; i++) { + diff[i] = static_cast(output_data[i]) - + static_cast(reference_output_data[i]); + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + saturated_min += output_data[i] == output_activation_min; + saturated_max += output_data[i] == output_activation_max; + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / output_buffer_size; + const float mean_abs_diff = + static_cast(sum_abs_diff) / output_buffer_size; + // Normally we should require bit-for-bit exact results. Unfortunately a bug + // in the Intel arm_neon_sse.h translation header that we use for x86 tests + // causes 1-bit inaccuracy in + // the vqrdmulh_n_s32 intrinsic, which causes off-by-1 errors in quantized + // DepthwiseConv ops. So we have to live with a few off-by-one errors for now, + // yet still ensure that no more than a small minority of values are wrong. + EXPECT_TRUE(std::abs(mean_diff) < 1e-5f && mean_abs_diff < 1e-5f && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1); + if (saturated_min > 2 * saturated_max) { + return -1; + } + if (saturated_max > 2 * saturated_min) { + return 1; + } + return 0; +} + +// The point of this function is that we can't practically know which +// output_shift value to pass to test DepthwiseConv. It's not easy to guess (we +// could do some +// statistics for large size, but they would be fragile at smaller sizes), and +// guessing wrong would mean that all the values get saturated so the test +// becomes +// vacuous. So we just bisect our way to reasonable output_shift values. +template +void TestOneDepthwiseConvBisectOutputShift( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + int output_activation_bisect_start, int output_activation_bisect_end, + std::int32_t output_activation_min, std::int32_t output_activation_max, + const Dims<4>& output_dims) { + ASSERT_LT(output_activation_bisect_start, output_activation_bisect_end) + << "Bisection failed ?!?!"; + int output_shift_bisect_midpoint = + (output_activation_bisect_start + output_activation_bisect_end) / 2; + int bisect_result = TestOneDepthwiseConvWithGivenOutputShift( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, + output_shift_bisect_midpoint, output_activation_min, + output_activation_max, output_dims); + // At this point we know that the test succeeded (otherwise it would have + // aborted). + if (bisect_result == 0) { + // The result isn't particularly saturated on one or the other side. + // All good, we're done. + return; + } + if (output_activation_bisect_start == output_activation_bisect_end - 1) { + // There is still some saturation on one side, but the bisection is + // finished anyways. We're done; nothing more we can do about it. This + // happens + // in particular when using an activation with a narrow range. + return; + } + // Continue the bisection based on the present result. + int new_output_activation_bisect_start = bisect_result == 1 + ? output_shift_bisect_midpoint + : output_activation_bisect_start; + int new_output_activation_bisect_end = bisect_result == 1 + ? output_activation_bisect_end + : output_shift_bisect_midpoint; + TestOneDepthwiseConvBisectOutputShift( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, + new_output_activation_bisect_start, new_output_activation_bisect_end, + output_activation_min, output_activation_max, output_dims); +} + +template +void TestOneDepthwiseConv( + const std::uint8_t* input_data, const Dims<4>& input_dims, + std::int32_t input_offset, const std::uint8_t* filter_data, + const Dims<4>& filter_dims, std::int32_t filter_offset, + const std::int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + std::int32_t output_offset, std::int32_t output_multiplier, + std::int32_t output_activation_min, std::int32_t output_activation_max, + const Dims<4>& output_dims) { + TestOneDepthwiseConvBisectOutputShift( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, + depth_multiplier, output_offset, output_multiplier, 0, 32, + output_activation_min, output_activation_max, output_dims); +} + +void TestOneDepthwiseConv( + FusedActivationFunctionType Ac, const std::uint8_t* input_data, + const Dims<4>& input_dims, std::int32_t input_offset, + const std::uint8_t* filter_data, const Dims<4>& filter_dims, + std::int32_t filter_offset, const std::int32_t* bias_data, + const Dims<4>& bias_dims, int stride, int pad_width, int pad_height, + int depth_multiplier, std::int32_t output_offset, + std::int32_t output_multiplier, std::int32_t output_activation_min, + std::int32_t output_activation_max, const Dims<4>& output_dims) { +#define TOCO_HANDLE_CASE(AC_TYPE) \ + if (AC_TYPE == Ac) { \ + TestOneDepthwiseConv( \ + input_data, input_dims, input_offset, filter_data, filter_dims, \ + filter_offset, bias_data, bias_dims, stride, pad_width, pad_height, \ + depth_multiplier, output_offset, output_multiplier, \ + output_activation_min, output_activation_max, output_dims); \ + return; \ + } + TOCO_HANDLE_CASE(FusedActivationFunctionType::kNone) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu1) + TOCO_HANDLE_CASE(FusedActivationFunctionType::kRelu6) +#undef TOCO_HANDLE_CASE +} + +bool TryTestDepthwiseConv(int batch, int input_depth, int input_width, + int input_height, int filter_width, int filter_height, + int depth_multiplier, int stride, + PaddingType padding_type) { + const int output_depth = input_depth * depth_multiplier; + // The optimized DepthwiseConv implementation currently uses a fixed-size + // accumulator buffer on the stack, with that size. This currently means + // that it does not support larger output depths. It CHECK's for it, + // so it's safe in the sense that if a larger output depth was encountered, + // it would explicitly fail. We just need to adjust our testing to that + // constraint. + const int kMaxSupportedOutputDepth = 1024; + if (output_depth > kMaxSupportedOutputDepth) { + return false; + } + const auto ac = RandomElement(std::vector( + {FusedActivationFunctionType::kNone, FusedActivationFunctionType::kRelu, + FusedActivationFunctionType::kRelu6, + FusedActivationFunctionType::kRelu1})); + int output_activation_min = 0; + int output_activation_max = 255; + if (ac != FusedActivationFunctionType::kNone && UniformRandomInt(0, 1)) { + output_activation_min = UniformRandomInt(0, 50); + output_activation_max = UniformRandomInt(200, 255); + } + const std::int32_t output_multiplier = + UniformRandomInt(1 << 29, std::numeric_limits::max()); + const std::int32_t input_offset = UniformRandomInt(-256, 0); + const std::int32_t filter_offset = UniformRandomInt(-256, 0); + const std::int32_t output_offset = UniformRandomInt(-256, 0); + Dims<4> input_dims_inference = + MakeDimsForInference(input_depth, input_width, input_height, batch); + Dims<4> output_dims_inference; + int pad_width, pad_height; + if (!ComputeConvSizes(input_dims_inference, output_depth, filter_width, + filter_height, stride, padding_type, + &output_dims_inference, &pad_width, &pad_height)) { + return false; + } + Dims<4> filter_dims_inference = + MakeDimsForInference(output_depth, filter_width, filter_height, 1); + Dims<4> bias_dims_inference = MakeDimsForInference(output_depth, 1, 1, 1); + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int filter_buffer_size = + RequiredBufferSizeForDims(filter_dims_inference); + std::vector input_data(input_buffer_size); + std::vector filter_data(filter_buffer_size); + std::vector bias_data(output_depth); + FillRandom(&input_data); + FillRandom(&filter_data); + FillRandom(&bias_data, -10000, 10000); + TestOneDepthwiseConv(ac, input_data.data(), input_dims_inference, + input_offset, filter_data.data(), filter_dims_inference, + filter_offset, bias_data.data(), bias_dims_inference, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_activation_min, + output_activation_max, output_dims_inference); + return true; +} + +// This function picks some random DepthwiseConv params, which may or may not +// be legal. If they're not legal, it returns false. If they're legal, +// it runs the DepthwiseConv test and returns true. This allows the caller +// to loop until a test has been run. +bool TryTestOneDepthwiseConv() { + // We have to pick a lot of positive values, where we are particularly + // interested in small values because they are most likely to be special + // cases in optimized implementations, and secondarily because they allow + // tests to run fast, which means we can run more tests and get more + // coverage. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int filter_height = ExponentialRandomPositiveInt(0.9f, 4, 10); + const int depth_multiplier = ExponentialRandomPositiveInt(0.8f, 6, 50); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + + return TryTestDepthwiseConv(batch, input_depth, input_width, input_height, + filter_width, filter_height, depth_multiplier, + stride, padding_type); +} + +// Tests parameters for the 3x3 filter kernel. +bool TryTestOneDepthwiseConv3x3Filter() { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int filter_width = 3; + const int filter_height = 3; + const int depth_multiplier = 1; + const int stride = UniformRandomInt(1, 2); + // Although the kernel supports only kValid padding, we test that kSame + // is using the correct code path. + const auto padding_type = + UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid; + + return TryTestDepthwiseConv(batch, input_depth, input_width, input_height, + filter_width, filter_height, depth_multiplier, + stride, padding_type); +} + +void TestOneDepthwiseConv() { + while (!TryTestOneDepthwiseConv()) { + } +} + +void TestOneDepthwiseConv3x3Filter() { + while (!TryTestOneDepthwiseConv3x3Filter()) { + } +} + +TEST(TestDepthwiseConv, TestDepthwiseConv) { + const int kTestsToRun = 10 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv(); + } +} + +TEST(TestDepthwiseConv3x3Filter, TestDepthwiseConv) { + const int kTestsToRun = 3 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + TestOneDepthwiseConv3x3Filter(); + } +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc index 5f9cfc450db1c25ff604f99f93481e5ca590a5a2..36c25388e8bde721d7644dc83d5b7c490d37b4d3 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -52,21 +52,19 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, - float* hidden_state_ptr_batch, float* output_ptr_batch) { + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch) { // Output = bias tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, output_ptr_batch); - // TODO(mirkov): change std::minmax_element with a vectorized call. - auto minmax_element = std::minmax_element( - input_ptr_batch, input_ptr_batch + batch_size * input_size); - // Save quantization and matmul computation for all zero input. - if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) { + if (!tensor_utils::IsZeroVector(input_ptr_batch, batch_size * input_size)) { // Quantize input from float to uint8 + quantization params (scaling // factor). float unused_min, unused_max; - float* scaling_factors = new float[batch_size]; + // TODO(mirkov,raziel): replace this for-loop with a MACRO (or function) + // whichever is faster. for (int b = 0; b < batch_size; ++b) { const int offset = b * input_size; tensor_utils::SymmetricQuantizeFloats( @@ -80,16 +78,13 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_weights_ptr, num_units, input_size, quantized_input_ptr_batch, scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1); - delete[] scaling_factors; } - minmax_element = std::minmax_element( - hidden_state_ptr_batch, hidden_state_ptr_batch + batch_size * num_units); // Save quantization and matmul computation for all zero input. - if (!(*minmax_element.first == 0.0 && *minmax_element.second == 0.0)) { + if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch, + batch_size * num_units)) { // Quantize hidden_state float unused_min, unused_max; - float* scaling_factors = new float[batch_size]; for (int b = 0; b < batch_size; ++b) { const int offset = b * num_units; tensor_utils::SymmetricQuantizeFloats( @@ -104,7 +99,6 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, recurrent_weights_ptr, num_units, num_units, quantized_hidden_state_ptr_batch, scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1); - delete[] scaling_factors; } // Output = activation(Output) and update hidden_state @@ -155,6 +149,7 @@ void LstmStep( input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, input_gate_scratch, /*result_stride=*/1); } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1); @@ -169,8 +164,7 @@ void LstmStep( if (!use_cifg) { tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, - n_batch, input_gate_scratch, - /*result_stride=*/1); + n_batch, input_gate_scratch, /*result_stride=*/1); } tensor_utils::MatrixBatchVectorMultiplyAccumulate( recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, @@ -261,5 +255,263 @@ void LstmStep( output_state_ptr); } +// TODO(alanchiao): move this to tensor_utils. +void VectorMultiply(const int8_t* vector, const int v_size, const float scale, + float* result) { + for (int i = 0; i < v_size; ++i) { + *result++ = scale * *vector++; + } +} + +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch) { + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights_ptr == nullptr); + const bool use_peephole = (cell_to_output_weights_ptr != nullptr); + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_input; + tensor_utils::SymmetricQuantizeFloats( + input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset, + &unused_min, &unused_max, &scaling_factors[b]); + } + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, + quantized_input_ptr_batch, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, forget_gate_scratch, + /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * input_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch, + product_scaling_factors, n_batch, output_gate_scratch, + /*result_stride=*/1); + } + + if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_output; + tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output, + quantized_output_state_ptr + offset, + &unused_min, &unused_max, + &scaling_factors[b]); + } + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_input_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_forget_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + forget_gate_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_cell_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + cell_scratch, /*result_stride=*/1); + + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * recurrent_to_output_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, + quantized_output_state_ptr, product_scaling_factors, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + + // Save quantization and matmul computation for all zero input. + bool is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_input_weights_ptr, n_cell, + 1. / cell_to_input_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_forget_weights_ptr, n_cell, + 1. / cell_to_forget_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + is_cell_state_all_zeros = + tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell); + // For each batch and cell: update the output gate. + if (use_peephole && !is_cell_state_all_zeros) { + VectorMultiply(cell_to_output_weights_ptr, n_cell, + 1. / cell_to_output_weights_scale, recovered_cell_weights); + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + recovered_cell_weights, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_batch); + } else { + tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output); + } + if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) { + // Save quantization and matmul computation for all zero input. + float unused_min, unused_max; + for (int b = 0; b < n_batch; ++b) { + const int offset = b * n_cell; + tensor_utils::SymmetricQuantizeFloats( + output_gate_scratch + offset, n_cell, + quantized_cell_state_ptr + offset, &unused_min, &unused_max, + &scaling_factors[b]); + } + for (int b = 0; b < n_batch; ++b) { + product_scaling_factors[b] = + scaling_factors[b] * projection_weights_scale; + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr, + product_scaling_factors, n_batch, output_ptr_batch, + /*result_stride=*/1); + } + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, + params->proj_clip, output_ptr_batch); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_batch); + } + tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output, + output_state_ptr); +} + } // namespace kernel_utils } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h index cbfbcbeefcd34fa732799d89f52791b18855857d..2a11b37a6069367e8232350c2fc68d4c385e14ba 100644 --- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -41,6 +41,9 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, // values of hidden_state_ptr_batch and input_ptr_batch, respectively. // These temporary storages are expected to be preallocated to the same size as // the respective pointers. +// An additional preallocated temporary storage 'scaling_factors' (of size +// batch_size) is used to store the scaling factors of the quantization (used +// for recovery). // {input,recurrent}_weights_scale params are used for dequantization/recovery. void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, float input_weights_scale, @@ -50,7 +53,8 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr, TfLiteFusedActivation activation, int8_t* quantized_input_ptr_batch, int8_t* quantized_hidden_state_ptr_batch, - float* hidden_state_ptr_batch, float* output_ptr_batch); + float* scaling_factors, float* hidden_state_ptr_batch, + float* output_ptr_batch); // Performs an LSTM batch inference step for input specified by input_ptr_batch. // The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and @@ -88,6 +92,89 @@ void LstmStep( float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch, float* output_ptr_batch); +// Same as above but with quantized weight matrices. In detail: +// Input of size 'n_batch * n_input': +// input_ptr_batch +// +// LSTM weights: +// Quantized input weights of size 'n_cell * n_input': +// input_to_input_weights - optional (can be nullptr) +// input_to_forget_weights +// input_to_cell_weights +// input_to_input_weights +// Quantized recurrent weights of size 'n_cell * n_output': +// recurrent_to_input_weights - optional +// recurrent_to_forget_weights +// recurrent_to_cell_weights +// recurrent_to_input_weights +// Quantized peephole weights of size 'n_cell', representing diagonal matrices. +// cell_to_input_weights - optional +// cell_to_cell_weights - optional +// cell_to_output_weights - optional +// Quantized projection weights of size 'n_output * n_cell' +// projection_weights_ptr - optional +// Weight scales (scalars) for each of the weights above. +// input_to_input_weights_scale - optional +// input_to_forget_weights_scale +// input_to_cell_weights_scale +// input_to_output_weights_scale +// recurrent_to_input_weights_scale - optional +// recurrent_to_forget_weights_scale +// recurrent_to_cell_weights_scale +// recurrent_to_output_weights_scale +// cell_to_input_weights_scale, +// cell_to_forget_weights_scale, +// cell_to_output_weights_scale, +// projection_weights_scale - optional +// Gate biases of size 'n_cell': +// input_gate_bias_ptr - optional +// forget_gate_bias_ptr +// cell_gate_bias_ptr +// output_gate_bias_ptr +// +// Temporary pre-allocated storage for quantized values: +// quantized_input_ptr_batch (same size as input_ptr_batch) +// quantized_output_state_ptr (same size as output_state_ptr) +// quantized_cell_state_ptr (same size as cell_state_ptr) +// Temporary pre-allocated storage for recovered values: +// recovered_cell_weights (same size as cell_to_*_weights) +// +// Outputs: +// output_state_ptr - size 'n_batch * n_output' +// cell_state_ptr - size 'n_batch * n_cell' +// output_ptr_batch - size 'n_batch * n_output' +void LstmStep( + const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr, + float input_to_input_weights_scale, + const int8_t* input_to_forget_weights_ptr, + float input_to_forget_weights_scale, + const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale, + const int8_t* input_to_output_weights_ptr, + float input_to_output_weights_scale, + const int8_t* recurrent_to_input_weights_ptr, + float recurrent_to_input_weights_scale, + const int8_t* recurrent_to_forget_weights_ptr, + float recurrent_to_forget_weights_scale, + const int8_t* recurrent_to_cell_weights_ptr, + float recurrent_to_cell_weights_scale, + const int8_t* recurrent_to_output_weights_ptr, + float recurrent_to_output_weights_scale, + const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale, + const int8_t* cell_to_forget_weights_ptr, + float cell_to_forget_weights_scale, + const int8_t* cell_to_output_weights_ptr, + float cell_to_output_weights_scale, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr, + float projection_weights_scale, const float* projection_bias_ptr, + const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input, + int n_output, float* input_gate_scratch, float* forget_gate_scratch, + float* cell_scratch, float* output_gate_scratch, float* scaling_factors, + float* product_scaling_factors, float* recovered_cell_weights, + int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr, + int8_t* quantized_cell_state_ptr, float* output_state_ptr, + float* cell_state_ptr, float* output_ptr_batch); + } // namespace kernel_utils } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e9ff5242a43a8b54e0e6ae167cdcf7a341c918e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/log_quantized_test.cc @@ -0,0 +1,333 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GEMMLOWP_ENABLE_FIXEDPOINT_CONSTANTS_CHECKS + +#include +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" + +namespace { + +class NumberGenerator { + public: + std::vector RandomIntVector(int n, int min_val, int max_val) { + std::vector vec(n); + double scale = static_cast(max_val + 1 - min_val) / engine_.max(); + for (auto& it : vec) { + it = min_val + std::floor(engine_() * scale); + } + return vec; + } + + std::mt19937 engine_; +}; + +class LogQuantizedTest : public ::testing::Test { + public: + NumberGenerator generator_; +}; + +// input_integer_bits <= 30. output_integer_bits > 0. +inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits, + int output_integer_bits) { + const double float_log_sum_of_exps = std::log( + static_cast(input_val) * 0.5 / (1 << (30 - input_integer_bits))); + static constexpr double min_int = + static_cast(std::numeric_limits::min()); + static constexpr double max_int = + static_cast(std::numeric_limits::max()); + double double_result = tflite::TfLiteRound(float_log_sum_of_exps * + (1 << (31 - output_integer_bits))); + return static_cast( + std::min(max_int, std::max(min_int, double_result))); +} + +void CheckOutputData(const std::vector& test_output, + const std::vector& reference_output, + const std::vector& test_input, + const string& check_label, int input_integer_bits, + int output_integer_bits, int tolerance) { + // In the special case of small input, specifically raw value of 5, a rounding + // up leads to difference in the output. We do not aim to be accurate for + // very small input values, and there should be sufficient input fractional + // bits that this is a small input. + static constexpr double error_from_rounding_up = 0.0224585; + const int n = test_output.size(); + ASSERT_EQ(n, reference_output.size()); + for (int i = 0; i < n; ++i) { + // Adjust tolerance when input <= 5*2^-(31-input_integer_bits). + const int adjusted_tolerance = + test_input[i] > 5 + ? tolerance + : std::max(tolerance, static_cast(std::ceil( + error_from_rounding_up * + (1 << (31 - output_integer_bits))))); + ASSERT_LE(std::abs(test_output[i] - reference_output[i]), + adjusted_tolerance) + << "Failure in \"" << check_label << "\" at i=" << i + << ", test_input[i]=" << test_input[i] << "=" + << static_cast(test_input[i]) / (1 << (31 - input_integer_bits)) + << ", test_output[i]=" << test_output[i] << "=" + << static_cast(test_output[i]) / + (1 << (31 - output_integer_bits)) + << ", reference_output[i]=" << reference_output[i] << "=" + << static_cast(reference_output[i]) / + (1 << (31 - output_integer_bits)) + << ", difference[i]=" << std::abs(reference_output[i] - test_output[i]) + << "=" + << static_cast(std::abs(reference_output[i] - test_output[i])) / + (1 << (31 - output_integer_bits)) + << "; tolerance=" << tolerance + << ", adj tolerance=" << adjusted_tolerance; + } +} + +void RightShiftVector(const std::vector& shifts, + std::vector* vec) { + const int n = vec->size(); + ASSERT_EQ(n, shifts.size()); + for (int i = 0; i < n; ++i) { + vec->at(i) = std::max(1, vec->at(i) >> shifts[i]); + } +} + +template +void RunSingleTest(const std::vector& test_input, + const string& check_label, int tolerance) { + const int n = test_input.size(); + std::vector float_gen_output(n, 0); + std::vector reference_output(n, 0); + std::vector optimized_output(n, 0); + + // Workaround the stupid things that intelligent humans do. + // Consequence of __builtin_clz(0u) may equal 31 instead of 32. + std::vector fudged_input(n, 0); + for (int i = 0; i < n; ++i) { + fudged_input[i] = std::max(test_input[i], 2); + } + + for (int i = 0; i < n; ++i) { + reference_output[i] = + tflite::reference_ops::log_x_for_x_greater_than_or_equal_to_1_impl< + OutputIntegerBits, InputIntegerBits>( + gemmlowp::FixedPoint::FromRaw( + fudged_input[i])) + .raw(); + optimized_output[i] = + tflite::optimized_ops::log_x_for_x_greater_than_or_equal_to_1_impl< + OutputIntegerBits, InputIntegerBits>( + gemmlowp::FixedPoint::FromRaw( + fudged_input[i])) + .raw(); + float_gen_output[i] = LogPositiveValuesViaFloat( + fudged_input[i], InputIntegerBits, OutputIntegerBits); + } + // Note that first check is intolerant. + { + std::ostringstream label; + label << check_label << " / optimized vs reference / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + optimized_output, reference_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, 0); + } + { + std::ostringstream label; + label << check_label << " / reference vs float-gen / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + reference_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); + } + { + std::ostringstream label; + label << check_label << " optimized vs float-gen / InputIntegerBits=" + << InputIntegerBits << ", OutputIntegerBits=" << OutputIntegerBits; + CheckOutputData( + optimized_output, float_gen_output, test_input, label.str(), + InputIntegerBits, OutputIntegerBits, tolerance); + } +} + +template +void RunSingleTest(const std::vector& test_input, int input_integer_bits, + const string& check_label, int tolerance) { +#define INPUT_CASE(K) \ + case K: \ + return RunSingleTest(test_input, check_label, \ + tolerance) + switch (input_integer_bits) { + INPUT_CASE(0); + INPUT_CASE(1); + INPUT_CASE(2); + INPUT_CASE(3); + INPUT_CASE(4); + INPUT_CASE(5); + INPUT_CASE(6); + INPUT_CASE(7); + INPUT_CASE(8); + INPUT_CASE(9); + INPUT_CASE(10); + INPUT_CASE(11); + INPUT_CASE(12); + INPUT_CASE(13); + INPUT_CASE(14); + INPUT_CASE(15); + INPUT_CASE(16); + INPUT_CASE(17); + INPUT_CASE(18); + INPUT_CASE(19); + INPUT_CASE(20); + INPUT_CASE(21); + INPUT_CASE(22); + INPUT_CASE(23); + INPUT_CASE(24); + INPUT_CASE(25); + INPUT_CASE(26); + INPUT_CASE(27); + INPUT_CASE(28); + INPUT_CASE(29); + default: + ASSERT_LE(input_integer_bits, 30) + << "Input integer bits not handled: " << input_integer_bits; + } +#undef INPUT_CASE +} + +void RunSingleTest(const std::vector& test_input, int input_integer_bits, + int output_integer_bits, const string& check_label, + int tolerance) { +#define OUTPUT_CASE(K) \ + case K: \ + return RunSingleTest(test_input, input_integer_bits, check_label, \ + tolerance) + switch (output_integer_bits) { + OUTPUT_CASE(0); + OUTPUT_CASE(1); + OUTPUT_CASE(2); + OUTPUT_CASE(3); + OUTPUT_CASE(4); + OUTPUT_CASE(5); + OUTPUT_CASE(6); + OUTPUT_CASE(7); + OUTPUT_CASE(8); + OUTPUT_CASE(9); + OUTPUT_CASE(10); + OUTPUT_CASE(11); + OUTPUT_CASE(12); + OUTPUT_CASE(13); + OUTPUT_CASE(14); + OUTPUT_CASE(15); + OUTPUT_CASE(16); + OUTPUT_CASE(17); + OUTPUT_CASE(18); + OUTPUT_CASE(19); + OUTPUT_CASE(20); + OUTPUT_CASE(21); + OUTPUT_CASE(22); + OUTPUT_CASE(23); + OUTPUT_CASE(24); + OUTPUT_CASE(25); + OUTPUT_CASE(26); + OUTPUT_CASE(27); + OUTPUT_CASE(28); + OUTPUT_CASE(29); + default: + ASSERT_LE(input_integer_bits, 30) + << "Input integer bits not handled: " << input_integer_bits; + } +#undef OUTPUT_CASE +} + +void RunUniformTest(int test_size, int input_integer_bits, + int output_integer_bits, const string& check_label, + int tolerance, NumberGenerator* generator) { + std::vector test_data = generator->RandomIntVector( + test_size, 2, std::numeric_limits::max() - 1); + test_data[0] = 2; + test_data[1] = 3; + test_data[2] = 4; + test_data[3] = std::numeric_limits::max() - 2; + test_data[4] = std::numeric_limits::max() - 1; + test_data[5] = std::numeric_limits::max(); + + RunSingleTest(test_data, input_integer_bits, output_integer_bits, + check_label + " / uniform test", tolerance); +} + +void RunUniformShiftUniformTest(int test_size, int input_integer_bits, + int output_integer_bits, + const string& check_label, int tolerance, + NumberGenerator* generator) { + std::vector test_data = generator->RandomIntVector( + test_size, 2, std::numeric_limits::max() - 1); + std::vector shifts = generator->RandomIntVector(test_size, 0, 29); + RightShiftVector(shifts, &test_data); + + RunSingleTest(test_data, input_integer_bits, output_integer_bits, + check_label + " / shifted test", tolerance); +} + +TEST_F(LogQuantizedTest, VariedIntegerBits) { + static constexpr int kVariations = 250; + static constexpr int kRunSize = 250; + static constexpr int kIntegerTolerance = 8; + static constexpr double kOutputFloatTolerance = 7.0e-7; + + std::vector input_integer_bits = + generator_.RandomIntVector(kVariations, 0, 24); + std::vector output_integer_bits = + generator_.RandomIntVector(kVariations, 1, 10); + + for (int i = 0; i < kVariations; ++i) { + int var_output_integer_bits = output_integer_bits[i]; + int tolerance = + std::max(1.0 * kIntegerTolerance, + (1 << (31 - var_output_integer_bits)) * kOutputFloatTolerance); + + RunUniformTest(kRunSize, input_integer_bits[i], var_output_integer_bits, + "VariedIntegerBits", tolerance, &generator_); + RunUniformShiftUniformTest(kRunSize, input_integer_bits[i], + var_output_integer_bits, "VariedIntegerBits", + tolerance, &generator_); + } +} + +TEST_F(LogQuantizedTest, SelectedIntegerBits) { + static constexpr int kInputBits = 12; + static constexpr int kOutputBits = 5; + static constexpr int kRunSize = 100000; + static constexpr int kIntegerTolerance = 4; + + RunUniformTest(kRunSize, kInputBits, kOutputBits, "SelectedIntegerBits", + kIntegerTolerance, &generator_); + RunUniformShiftUniformTest(kRunSize, kInputBits, kOutputBits, + "SelectedIntegerBits", kIntegerTolerance, + &generator_); +} + +} // namespace diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e786f785abe3aa66a9fb243dd4f332ca91676863 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -0,0 +1,242 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +namespace tflite { +namespace { + +void RunLogSoftmaxFloatReference(const uint8* input_data, + const Dims<4>& dims_common, int32 input_offset, + const double input_scale, int stride, + float beta, uint8* reference_output_data) { + const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float LogSoftmax. + reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, + reference_dequant_data.data(), dims_common); + optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common, + reference_output_float_data.data(), dims_common); + // Work with quantized scaling for LogSoftmax, under which 255 represents 0, + // and -16 gets nudged up to 0. + for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::max( + 0, static_cast( + 255 + std::round(16.0f * reference_output_float_data[i]))); + } +} + +void CheckOutputData(const uint8* test_output, const uint8* reference_output, + const Dims<4>& dims_common, const string& check_label, + bool be_exacting) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + // While calculating some metrics in floating point, we work with quantized + // scaling. + std::vector diff(buffer_size); + int64_t sum_diff = 0; + int64_t sum_abs_diff = 0; + for (int i = 0; i < buffer_size; i++) { + diff[i] = static_cast(test_output[i]) - reference_output[i]; + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / buffer_size; + const float mean_abs_diff = static_cast(sum_abs_diff) / buffer_size; + // We either check for bit exactness (against the reference quantized version) + // or for general accuracy, allowing off-by-one (against the float reference). + if (be_exacting) { + ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0) + << check_label << ": " + << "std::abs(min_diff)=" << std::abs(min_diff) + << ", std::abs(max_diff)=" << std::abs(max_diff); + } else { + // For small numbers of samples, the estimates of the means vary more. + // Rather than widen the tolerances, we skip the smaller tests. + ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) || + buffer_size < 10000) && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1) + << check_label << ": " + << "buffer_size=" << buffer_size << ", mean_diff=" << mean_diff + << ", mean_abs_diff=" << mean_abs_diff + << ", median_diff=" << median_diff << ", min_diff=" << min_diff + << ", max_diff=" << max_diff; + } +} + +// Runs the LogSoftmax and compares against the float reference implementation +// and the quantized reference implementation. +void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, + int32 input_offset, const double input_scale, + int stride, float beta) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector optimized_logsoftmax_output(buffer_size); + std::vector reference_float_logsoftmax_output(buffer_size); + std::vector reference_quant_logsoftmax_output(buffer_size); + + RunLogSoftmaxFloatReference(input_data, dims_common, input_offset, + input_scale, stride, beta, + reference_float_logsoftmax_output.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + int32 reverse_scaling_divisor; + int reverse_scaling_right_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessLogSoftmaxScalingExp( + beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, + &input_beta_left_shift, &reverse_scaling_divisor, + &reverse_scaling_right_shift); + reverse_scaling_right_shift *= -1; + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. + const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier, + input_beta_left_shift, reverse_scaling_divisor, + reverse_scaling_right_shift, diff_min, + optimized_logsoftmax_output.data(), dims_common); + reference_ops::LogSoftmax( + input_data, dims_common, input_beta_multiplier, input_beta_left_shift, + reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, + reference_quant_logsoftmax_output.data(), dims_common); + + CheckOutputData(optimized_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), dims_common, + "Optimized vs float reference", false); + CheckOutputData(optimized_logsoftmax_output.data(), + reference_quant_logsoftmax_output.data(), dims_common, + "Optimized vs quant reference", true); + CheckOutputData(reference_quant_logsoftmax_output.data(), + reference_float_logsoftmax_output.data(), dims_common, + "Quant reference vs float reference", false); +} + +// This function picks some random LogSoftmax params, which are checked for +// desirability. If not acceptable, it returns false. If they're OK, +// it runs the LogSoftmax test and returns true. This allows the caller +// to loop until a test has been run. +// +// Currently we do not reject for any reason. +bool TryOneUniformLogSoftmax() { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // LogSoftmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + static constexpr float beta = 1.0f; + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandom(&input_data); + RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + input_scale, stride, beta); + return true; +} + +// See TryOneUniformLogSoftmax() for a general description. +// +// Tests with "skyscraper" input patterns are included for two reasons. (a) +// Bimodal distributions are potentially challenging and perhaps more +// realistic than simple uniform random inputs. (b) Some implementations of +// LogSoftmax may adapt as they traverse the depth, and so we test handling of +// cases where relatively small values are encountered at the beginning and end. +bool TryOneSkyscraperLogSoftmax(bool small_depth) { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // LogSoftmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = small_depth + ? ExponentialRandomPositiveInt(0.75f, 40, 500) + : ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + static constexpr float beta = 1.0f; + // Extra parameters for skyscraper input patterns. + const double middle_proportion = + ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0); + const int middle_min = UniformRandomInt(0, 255); + const int sides_max = UniformRandomInt(0, middle_min); + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, + sides_max); + RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + input_scale, stride, beta); + return true; +} + +TEST(TestQuantizedLogSoftmax, UniformLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformLogSoftmax()) { + } + } +} + +TEST(TestQuantizedLogSoftmax, SkyscraperLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperLogSoftmax(false)) { + } + } +} + +TEST(TestQuantizedLogSoftmax, SmallSkyscraperLogSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperLogSoftmax(true)) { + } + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index dd6932ffe7b7a6f1101f146ce6472b0df4cbda3b..3fd00c89308d3b163111fda004287c715259352f 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1691,14 +1691,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, const int filter_width = ArraySize(filter_dims, 1); const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); +#ifdef USE_NEON + const bool shift_left = (output_shift <= 0); + const int32 multiplier_power_of_two = shift_left ? (1 << -output_shift) : 1; +#endif TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); -#ifdef __aarch64__ +// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on +// Jetson TX-2. This compiler does not support the offsetof() macro. +#if defined(__aarch64__) && !defined(GOOGLE_L4T) // Call kernel optimized for depthwise convolutions using 3x3 filters if // parameters are supported. - if (Fast3x3FilterKernelSupported(input_dims, filter_dims, stride_width, - stride_height, pad_width, pad_height, - depth_multiplier, output_dims)) { + if (Fast3x3FilterKernelSupported( + input_dims, filter_dims, stride_width, stride_height, pad_width, + pad_height, depth_multiplier, output_dims, output_shift)) { DepthwiseConv3x3Filter(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride_width, stride_height, pad_width, pad_height, @@ -1833,12 +1839,20 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, acc[j] = vld1q_s32(acc_buffer + i + 4 * j); } - // Fixed-point multiplication. - for (int j = 0; j < 4; j++) { - acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); - } - for (int j = 0; j < 4; j++) { - acc[j] = RoundingDivideByPOT(acc[j], output_shift); + if (!shift_left) { + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } + for (int j = 0; j < 4; j++) { + acc[j] = RoundingDivideByPOT(acc[j], output_shift); + } + } else { + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vmulq_n_s32(acc[j], multiplier_power_of_two); + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } } // Add the output offset. for (int j = 0; j < 4; j++) { @@ -1870,12 +1884,21 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, for (; i <= num_output_values - 8; i += 8) { int32x4_t acc0 = vld1q_s32(acc_buffer + i); int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4); - // Fixed-point multiplication. - acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); - acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); - // Rounding right shift. - acc0 = RoundingDivideByPOT(acc0, output_shift); - acc1 = RoundingDivideByPOT(acc1, output_shift); + if (!shift_left) { + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + // Rounding right shift. + acc0 = RoundingDivideByPOT(acc0, output_shift); + acc1 = RoundingDivideByPOT(acc1, output_shift); + } else { + // Fixed-point multiplication. + acc0 = vmulq_n_s32(acc0, multiplier_power_of_two); + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + + acc1 = vmulq_n_s32(acc1, multiplier_power_of_two); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + } // Add the output offset. acc0 = vaddq_s32(acc0, output_offset_vec); acc1 = vaddq_s32(acc1, output_offset_vec); @@ -1899,10 +1922,16 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // that will have to go through the very slow scalar code. for (; i <= num_output_values - 4; i += 4) { int32x4_t acc = vld1q_s32(acc_buffer + i); - // Fixed-point multiplication. - acc = vqrdmulhq_n_s32(acc, output_multiplier); - // Rounding right shift. - acc = RoundingDivideByPOT(acc, output_shift); + if (!shift_left) { + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + // Rounding right shift. + acc = RoundingDivideByPOT(acc, output_shift); + } else { + // Fixed-point multiplication. + acc = vmulq_n_s32(acc, multiplier_power_of_two); + acc = vqrdmulhq_n_s32(acc, output_multiplier); + } // Add the output offset. acc = vaddq_s32(acc, output_offset_vec); // Apply the activation function. @@ -1923,8 +1952,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Handle leftover values, one by one. This is very slow. for (; i < num_output_values; i++) { int32 acc = acc_buffer[i]; - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + -output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 55e0d5c3aa9ebb8b46403550e190b00a54cb53e5..a7b0d805a3acd35b592a35ba4266dfff4eb992cd 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -23,3848 +23,2912 @@ limitations under the License. namespace tflite { namespace optimized_ops { -#ifdef __aarch64__ - -inline void preload_l1_keep(const uint8* ptr) { -#ifdef GEMMLOWP_ARM_64 - asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); -#else - gemmlowp::Prefetch(ptr); -#endif -} - -// Implementation of quantized DepthwiseConv for 3x3 filters. - -// Below are helper structs to remove the use of arrays. -// There is an llvm bug that causes significant slowdown when using arrays for -// NEON intrinsics vector data types. -// See: https://bugs.llvm.org/show_bug.cgi?id=34945 - -struct Int32x8 { - int32x4_t low, high; -}; - -struct Filter3x3x8 { - int16x8_t f0, f1, f2, f3, f4, f5, f6, f7, f8; -}; - -// Loads 3x3 filter of depth 8 and adds filter offsets. -inline Filter3x3x8 Load3x3Filter(const uint8* filter_ptr, int32 filter_offset, - int output_depth) { - Filter3x3x8 filter; - - uint8x8_t temp_u8_0, temp_u8_1, temp_u8_2, temp_u8_3, temp_u8_4, temp_u8_5, - temp_u8_6, temp_u8_7, temp_u8_8; - int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); - - temp_u8_0 = vld1_u8(filter_ptr + 0 * output_depth); - temp_u8_1 = vld1_u8(filter_ptr + 1 * output_depth); - temp_u8_2 = vld1_u8(filter_ptr + 2 * output_depth); - temp_u8_3 = vld1_u8(filter_ptr + 3 * output_depth); - temp_u8_4 = vld1_u8(filter_ptr + 4 * output_depth); - temp_u8_5 = vld1_u8(filter_ptr + 5 * output_depth); - temp_u8_6 = vld1_u8(filter_ptr + 6 * output_depth); - temp_u8_7 = vld1_u8(filter_ptr + 7 * output_depth); - temp_u8_8 = vld1_u8(filter_ptr + 8 * output_depth); - - filter.f0 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_0)); - filter.f1 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_1)); - filter.f2 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_2)); - filter.f3 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_3)); - filter.f4 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_4)); - filter.f5 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_5)); - filter.f6 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_6)); - filter.f7 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_7)); - filter.f8 = vreinterpretq_s16_u16(vmovl_u8(temp_u8_8)); - - filter.f0 = vaddq_s16(filter.f0, filter_offset_vec); - filter.f1 = vaddq_s16(filter.f1, filter_offset_vec); - filter.f2 = vaddq_s16(filter.f2, filter_offset_vec); - filter.f3 = vaddq_s16(filter.f3, filter_offset_vec); - filter.f4 = vaddq_s16(filter.f4, filter_offset_vec); - filter.f5 = vaddq_s16(filter.f5, filter_offset_vec); - filter.f6 = vaddq_s16(filter.f6, filter_offset_vec); - filter.f7 = vaddq_s16(filter.f7, filter_offset_vec); - filter.f8 = vaddq_s16(filter.f8, filter_offset_vec); - - return filter; -} - -// Applies activation, offset and downquantize on a set of accumulator -// registers that correspond to a 2x2 output of depth 8. -// Stores results to output. -inline void DownquantizeAndStore2x2Output( - Int32x8 acc_0, Int32x8 acc_1, Int32x8 acc_2, Int32x8 acc_3, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - // Fixed-point multiplication. - acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier); - acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier); - acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier); - acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier); - acc_2.low = vqrdmulhq_n_s32(acc_2.low, output_multiplier); - acc_2.high = vqrdmulhq_n_s32(acc_2.high, output_multiplier); - acc_3.low = vqrdmulhq_n_s32(acc_3.low, output_multiplier); - acc_3.high = vqrdmulhq_n_s32(acc_3.high, output_multiplier); - - acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift); - acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift); - acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift); - acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift); - acc_2.low = RoundingDivideByPOT(acc_2.low, output_shift); - acc_2.high = RoundingDivideByPOT(acc_2.high, output_shift); - acc_3.low = RoundingDivideByPOT(acc_3.low, output_shift); - acc_3.high = RoundingDivideByPOT(acc_3.high, output_shift); - - // Add the output offset. - acc_0.low = vaddq_s32(acc_0.low, output_offset_vec); - acc_0.high = vaddq_s32(acc_0.high, output_offset_vec); - acc_1.low = vaddq_s32(acc_1.low, output_offset_vec); - acc_1.high = vaddq_s32(acc_1.high, output_offset_vec); - acc_2.low = vaddq_s32(acc_2.low, output_offset_vec); - acc_2.high = vaddq_s32(acc_2.high, output_offset_vec); - acc_3.low = vaddq_s32(acc_3.low, output_offset_vec); - acc_3.high = vaddq_s32(acc_3.high, output_offset_vec); - - // Apply the activation function. - acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec); - acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec); - acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec); - acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec); - acc_2.low = vmaxq_s32(acc_2.low, output_activation_min_vec); - acc_2.high = vmaxq_s32(acc_2.high, output_activation_min_vec); - acc_3.low = vmaxq_s32(acc_3.low, output_activation_min_vec); - acc_3.high = vmaxq_s32(acc_3.high, output_activation_min_vec); - - acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec); - acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec); - acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec); - acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec); - acc_2.low = vminq_s32(acc_2.low, output_activation_max_vec); - acc_2.high = vminq_s32(acc_2.high, output_activation_max_vec); - acc_3.low = vminq_s32(acc_3.low, output_activation_max_vec); - acc_3.high = vminq_s32(acc_3.high, output_activation_max_vec); - - // Saturating cast to uint8 and store to destination. - int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low); - int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high); - int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low); - int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high); - int16x4_t acc_2_low_s16 = vqmovn_s32(acc_2.low); - int16x4_t acc_2_high_s16 = vqmovn_s32(acc_2.high); - int16x4_t acc_3_low_s16 = vqmovn_s32(acc_3.low); - int16x4_t acc_3_high_s16 = vqmovn_s32(acc_3.high); - - int16x8_t res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16); - int16x8_t res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16); - int16x8_t res_2_s16 = vcombine_s16(acc_2_low_s16, acc_2_high_s16); - int16x8_t res_3_s16 = vcombine_s16(acc_3_low_s16, acc_3_high_s16); - - uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16); - uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16); - uint8x8_t res_2_u8 = vqmovun_s16(res_2_s16); - uint8x8_t res_3_u8 = vqmovun_s16(res_3_s16); - - vst1_u8(output_ptr, res_0_u8); - vst1_u8(output_ptr + output_depth, res_1_u8); - vst1_u8(output_ptr + output_depth * output_width, res_2_u8); - vst1_u8(output_ptr + output_depth * output_width + output_depth, res_3_u8); -} - -inline void DownquantizeAndStore(Int32x8 acc, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, - uint8* output_ptr) { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - acc.low = vqrdmulhq_n_s32(acc.low, output_multiplier); - acc.high = vqrdmulhq_n_s32(acc.high, output_multiplier); - - acc.low = RoundingDivideByPOT(acc.low, output_shift); - acc.high = RoundingDivideByPOT(acc.high, output_shift); - - acc.low = vaddq_s32(acc.low, output_offset_vec); - acc.high = vaddq_s32(acc.high, output_offset_vec); - - acc.low = vmaxq_s32(acc.low, output_activation_min_vec); - acc.high = vmaxq_s32(acc.high, output_activation_min_vec); - - acc.low = vminq_s32(acc.low, output_activation_max_vec); - acc.high = vminq_s32(acc.high, output_activation_max_vec); - - int16x4_t acc_low_s16 = vqmovn_s32(acc.low); - int16x4_t acc_high_s16 = vqmovn_s32(acc.high); - - int16x8_t res_s16 = vcombine_s16(acc_low_s16, acc_high_s16); - uint8x8_t res_u8 = vqmovun_s16(res_s16); - vst1_u8(output_ptr, res_u8); -} - -inline void DownquantizeAndStore2Output( - Int32x8 acc_0, Int32x8 acc_1, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - { - using gemmlowp::RoundingDivideByPOT; - const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); - const int32x4_t output_activation_min_vec = - vdupq_n_s32(output_activation_min); - const int32x4_t output_activation_max_vec = - vdupq_n_s32(output_activation_max); - - // Fixed-point multiplication. - acc_0.low = vqrdmulhq_n_s32(acc_0.low, output_multiplier); - acc_0.high = vqrdmulhq_n_s32(acc_0.high, output_multiplier); - acc_1.low = vqrdmulhq_n_s32(acc_1.low, output_multiplier); - acc_1.high = vqrdmulhq_n_s32(acc_1.high, output_multiplier); - - acc_0.low = RoundingDivideByPOT(acc_0.low, output_shift); - acc_0.high = RoundingDivideByPOT(acc_0.high, output_shift); - acc_1.low = RoundingDivideByPOT(acc_1.low, output_shift); - acc_1.high = RoundingDivideByPOT(acc_1.high, output_shift); - - // Add the output offset. - acc_0.low = vaddq_s32(acc_0.low, output_offset_vec); - acc_0.high = vaddq_s32(acc_0.high, output_offset_vec); - acc_1.low = vaddq_s32(acc_1.low, output_offset_vec); - acc_1.high = vaddq_s32(acc_1.high, output_offset_vec); - - // Apply the activation function. - acc_0.low = vmaxq_s32(acc_0.low, output_activation_min_vec); - acc_0.high = vmaxq_s32(acc_0.high, output_activation_min_vec); - acc_1.low = vmaxq_s32(acc_1.low, output_activation_min_vec); - acc_1.high = vmaxq_s32(acc_1.high, output_activation_min_vec); - - acc_0.low = vminq_s32(acc_0.low, output_activation_max_vec); - acc_0.high = vminq_s32(acc_0.high, output_activation_max_vec); - acc_1.low = vminq_s32(acc_1.low, output_activation_max_vec); - acc_1.high = vminq_s32(acc_1.high, output_activation_max_vec); - } - - // Saturating cast to uint8 and store to destination. - int16x8_t res_0_s16; - { - int16x4_t acc_0_low_s16 = vqmovn_s32(acc_0.low); - int16x4_t acc_0_high_s16 = vqmovn_s32(acc_0.high); - res_0_s16 = vcombine_s16(acc_0_low_s16, acc_0_high_s16); - } - - int16x8_t res_1_s16; - { - int16x4_t acc_1_low_s16 = vqmovn_s32(acc_1.low); - int16x4_t acc_1_high_s16 = vqmovn_s32(acc_1.high); - res_1_s16 = vcombine_s16(acc_1_low_s16, acc_1_high_s16); - } - - uint8x8_t res_0_u8 = vqmovun_s16(res_0_s16); - uint8x8_t res_1_u8 = vqmovun_s16(res_1_s16); - vst1_u8(output_ptr, res_0_u8); - vst1_u8(output_ptr + output_ptr_offset, res_1_u8); -} - -// Performs multiply accumulate on 3 inputs of depth 8. -inline Int32x8 MultiplyAccumulateRow(Int32x8 accum, int16x8_t f0, int16x8_t f1, - int16x8_t f2, int16x8_t i0, int16x8_t i1, - int16x8_t i2) { - accum.low = vmlal_s16(accum.low, vget_low_s16(f0), vget_low_s16(i0)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f0), vget_high_s16(i0)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f1), vget_low_s16(i1)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f1), vget_high_s16(i1)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f2), vget_low_s16(i2)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f2), vget_high_s16(i2)); - return accum; -} - -// Performs multiply accumulate on 3 inputs of depth 8. -inline Int32x8 MultiplyAccumulate3x3Filter(const Filter3x3x8& f, int16x8_t i0, - int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, - int16x8_t i5, int16x8_t i6, - int16x8_t i7, int16x8_t i8, - Int32x8 accum) { - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f0), vget_low_s16(i0)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f0), vget_high_s16(i0)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f1), vget_low_s16(i1)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f1), vget_high_s16(i1)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f2), vget_low_s16(i2)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f2), vget_high_s16(i2)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f3), vget_low_s16(i3)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f3), vget_high_s16(i3)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f4), vget_low_s16(i4)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f4), vget_high_s16(i4)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f5), vget_low_s16(i5)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f5), vget_high_s16(i5)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f6), vget_low_s16(i6)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f6), vget_high_s16(i6)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f7), vget_low_s16(i7)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f7), vget_high_s16(i7)); - accum.low = vmlal_s16(accum.low, vget_low_s16(f.f8), vget_low_s16(i8)); - accum.high = vmlal_s16(accum.high, vget_high_s16(f.f8), vget_high_s16(i8)); - return accum; -} - -inline void DotProductAndStore(const Filter3x3x8& filter, int16x8_t i0, - int16x8_t i1, int16x8_t i2, int16x8_t i3, - int16x8_t i4, int16x8_t i5, int16x8_t i6, - int16x8_t i7, int16x8_t i8, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr) { - Int32x8 acc; - acc.low = vld1q_s32(bias_ptr); - acc.high = vld1q_s32(bias_ptr + 4); - - acc = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, i8, - acc); - - DownquantizeAndStore(acc, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr); -} - -// Performs multiply-accumulate on a 3x4 input for 2 horizontal outputs. -inline void DotProductAndStore2xStride1( - const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7, - int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11, - const int32* bias_ptr, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i4, i5, i6, i8, i9, - i10, acc_0); - acc_1 = MultiplyAccumulate3x3Filter(filter, i1, i2, i3, i5, i6, i7, i9, i10, - i11, acc_1); - DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier, - output_shift, output_activation_min, - output_activation_max, output_ptr, - output_ptr_offset); -} - -// Performs multiply-accumulate on a 4x3 input for 2 vertical outputs. -inline void DotProductAndStore2yStride1( - const Filter3x3x8& filter, int16x8_t i0, int16x8_t i1, int16x8_t i2, - int16x8_t i3, int16x8_t i4, int16x8_t i5, int16x8_t i6, int16x8_t i7, - int16x8_t i8, int16x8_t i9, int16x8_t i10, int16x8_t i11, - const int32* bias_ptr, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_ptr, int output_ptr_offset) { - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulate3x3Filter(filter, i0, i1, i2, i3, i4, i5, i6, i7, - i8, acc_0); - acc_1 = MultiplyAccumulate3x3Filter(filter, i3, i4, i5, i6, i7, i8, i9, i10, - i11, acc_1); - DownquantizeAndStore2Output(acc_0, acc_1, output_offset, output_multiplier, - output_shift, output_activation_min, - output_activation_max, output_ptr, - output_ptr_offset); -} - -// A kernel that is optimized on the number of output cells in the x and y -// direction, and the stride. Assumes 3x3 filters of 8 depth. -template -struct ConvKernel3x3FilterDepth8 {}; - -template <> -struct ConvKernel3x3FilterDepth8<8, 8, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - // To process 8x8 outputs using a 3x3 filter, we require 10x10 inputs. - // Load inputs for the first 2 filters on the top left, then slide to - // the right, down, left, down, right, etc. in a snake-like path. This - // minimizes the total number of loads. - // - // INPUT OUTPUT - // |\----------------\ |\------------\ - // | \ \ | \ \ - // | \----------------\ | \------------\ - // | | 0 ... 9 | | | 0 ... 7 | - // | | 10 ... 19 | ---> | | 8 ... 15 | - // | | 20 ... 29 | \ | .. ... .. | - // \ | .. ... .. | \| 56 ... 63 | - // \| 90 ... 109 | |------------| - // |----------------| - // - // The first set of loads corresponds to: - // - // INPUT OUTPUT - // |\----------------- |\----------- - // | \ | \ - // | \----------------- | \---------- - // | | 0 1 2 3 ... | | 0 1 ... - // | | 10 11 12 13 ... ---> | | .. ... - // | | 20 21 22 23 ... | .. ... - // | | .. ... ... - // - // The next set of loads correspond to a sliding window to the right. - // It loads inputs 4, 5, 14, 15, 23, 24 and keeps 2, 3, 12, 13, and 22: - // - // INPUT OUTPUT - // |\------------------- |\------------- - // | \ | \ - // | \------------------- | \------------ - // | | .. 2 3 4 5 ... | | .. 2 3 ... - // | | .. 12 13 14 15 ... ---> | | .. ... - // | | .. 21 22 23 24 ... | .. ... - // | | .. ... ... - // - // And so on... - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. Referring to the - // indexes in the diagram above, this corresponds to outputs (0) and (1). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Slide to the right for outputs x = [2, 3], y = 0. Referring to the - // indexes in the diagram above, this corresponds to outputs (2) and (3). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Slide to the right again for outputs x = [4, 5], y = 0. Referring to the - // indexes in the diagram above, this corresponds to outputs (4) and (5). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_depth, output_depth); - - // Slide to the right one last time for outputs x = [6, 7], y = 0. - // Referring to the indexes in the diagram above, this corresponds to - // outputs (6) and (7). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_depth, output_depth); - - // Slide to down for outputs x = [6, 7], y = 1. Referring to the indexes in - // the diagram above, this corresponds to outputs (14) and (15). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_depth + output_row_size, - output_depth); - - // Slide left for outputs x = [4, 5], y = 1. Referring to the indexes in - // the diagram above, this corresponds to outputs (12) and (13). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_depth + output_row_size, - output_depth); - - // Slide left again for outputs x = [2, 3], y = 1. Referring to the indexes - // in the diagram above, this corresponds to outputs (10) and (11). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 1. Referring to the - // indexes in the diagram above, this corresponds to outputs (8) and (9). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (16) and (17). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (18) and (19). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 2 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (20) and (21). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 2 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 2. Referring to the - // indexes in the diagram above, this corresponds to outputs (22) and (23). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 2 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (30) and (31). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 3 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (28) and (29). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 3 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 3. Referring to the indexes in - // the diagram above, this corresponds to outputs (26) and (27). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 3 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 3. Referring to the - // indexes in the diagram above, this corresponds to outputs (24) and (25). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 3 * output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (32) and (33). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 4 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (34) and (35). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 4 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 4. Referring to the indexes in - // the diagram above, this corresponds to outputs (36) and (37). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 4 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 4. Referring to the - // indexes in the diagram above, this corresponds to outputs (38) and (39). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 4 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (46) and (47). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 5 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (44) and (45). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 5 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 5. Referring to the indexes in - // the diagram above, this corresponds to outputs (42) and (43). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 5 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 5. Referring to the - // indexes in the diagram above, this corresponds to outputs (40) and (41). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 5 * output_row_size, output_depth); - - // Slide down for outputs x = [0, 1], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (48) and (49). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 8 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 6 * output_row_size, output_depth); - - // Slide right for outputs x = [2, 3], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (50) and (51). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 6 * output_row_size, output_depth); - - // Slide right for outputs x = [4, 5], y = 6. Referring to the indexes in - // the diagram above, this corresponds to outputs (52) and (53). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 6 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 6 * output_row_size, output_depth); - - // Slide right one more time for outputs x = [6, 7], y = 6. Referring to the - // indexes in the diagram above, this corresponds to outputs (54) and (55). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 8 * input_depth + 6 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 6 * output_row_size, output_depth); - - // Slide down for outputs x = [6, 7], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (62) and (63). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 6 * input_depth + 9 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 6 * output_depth + 7 * output_row_size, output_depth); - - // Slide left for outputs x = [4, 5], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (60) and (61). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 4 * output_depth + 7 * output_row_size, output_depth); - - // Slide left for outputs x = [2, 3], y = 7. Referring to the indexes in the - // diagram above, this corresponds to outputs (58) and (59). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 2 * input_depth + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 7 * output_row_size, output_depth); - - // Slide left one more time for outputs x = [0, 1], y = 7. Referring to the - // indexes in the diagram above, this corresponds to outputs (56) and (57). - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 7 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 7 * output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - // To process 4x4 outputs using a 3x3 filter, we require 6x6 inputs. - // Load inputs for the first 2 filters on the top left, then slide to - // the right, down, left, down, right, etc. in a snake-like path. This - // minimizes the total number of loads. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the top right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Now load next inputs when sliding window left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, output_depth); - - // Now load next inputs when sliding window right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth + 2 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_10, input_11, input_8, input_9, input_2, input_3, input_0, - input_1, input_6, input_7, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 2 * output_row_size, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, - output_ptr + 2 * output_depth + 3 * output_row_size, output_depth); - - // Now load next inputs when sliding window left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 3 * output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 2, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load next inputs one row down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load next row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_8, input_9, input_10, input_11, input_0, input_1, input_2, - input_3, input_4, input_5, input_6, input_7, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Now load last row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 5 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<4, 1, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 2x1 outputs starting from the top. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_row_size); - - // Load inputs for bottom 2 rows. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_6, input_7, input_8, input_9, input_10, input_11, input_0, - input_1, input_2, input_3, input_4, input_5, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_row_size, - output_row_size); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 2, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1, acc_2, acc_3; - - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_2.low = vld1q_s32(bias_ptr); - acc_3.low = vld1q_s32(bias_ptr); - - bias_ptr += 4; - acc_0.high = vld1q_s32(bias_ptr); - acc_1.high = vld1q_s32(bias_ptr); - acc_2.high = vld1q_s32(bias_ptr); - acc_3.high = vld1q_s32(bias_ptr); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - // Add scope for input registers to help the compiler know that it is - // not needed. - { - // To process 2x2 outputs using a 3x3 filter, we require 4x4 inputs. - // Load inputs for the top two filters first. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - const uint8* ptr = input_ptr; - - // Load top 3 rows. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } +// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on +// Jetson TX-2. This compiler does not support the offsetof() macro. +#if defined(__aarch64__) && !defined(GOOGLE_L4T) - // Multiply-accum for top-left output. - acc_0 = MultiplyAccumulate3x3Filter(filter, input_0, input_1, input_2, - input_4, input_5, input_6, input_8, - input_9, input_10, acc_0); - - // Multiply-accum for top-right output. - acc_1 = MultiplyAccumulate3x3Filter(filter, input_1, input_2, input_3, - input_5, input_6, input_7, input_9, - input_10, input_11, acc_1); - - // Now load the bottom row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - } - - // Multiply-accum for bottom-left output. - acc_2 = MultiplyAccumulate3x3Filter(filter, input_4, input_5, input_6, - input_8, input_9, input_10, input_0, - input_1, input_2, acc_2); - - // Multiply-accum for bottom-right output. - acc_3 = MultiplyAccumulate3x3Filter(filter, input_5, input_6, input_7, - input_9, input_10, input_11, input_1, - input_2, input_3, acc_3); - } - - DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset, - output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - const int output_row_size = output_depth * output_width; - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } +// clang-format gets confused with this file and ends up formatting lines to +// be larger than 80 characters. Turn off here and back on at the end of the +// file. - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the top right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + 4 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - - // Now load next inputs when sliding window down. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr + 2 * input_depth + 3 * input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_6, input_7, input_4, input_5, input_10, input_11, input_8, - input_9, input_2, input_3, input_0, input_1, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth + output_row_size, - output_depth); - - // Now load next inputs when sliding window left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - } +// clang-format off - DotProductAndStore2xStride1( - filter, input_4, input_5, input_6, input_7, input_8, input_9, input_10, - input_11, input_0, input_1, input_2, input_3, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + output_row_size, output_depth); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<1, 4, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the left. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2xStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth); - - // Now load 1x2 inputs on the right. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr + input_depth * 4; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_2 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } +#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64 - DotProductAndStore2xStride1( - filter, input_2, input_3, input_0, input_1, input_6, input_7, input_4, - input_5, input_10, input_11, input_8, input_9, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr + 2 * output_depth, output_depth); - } +// Encapsulates constant parameters used in DepthwiseConv. +// 64-bit is used for types that will be added to 64-bit addresses in asm. +struct DepthwiseConvParams { + int64_t input_depth; + int64_t input_row_size; + int64_t output_depth; + int64_t output_row_size; + int64_t filter_row_size; + int32 input_offset; + int32 output_offset; + int32 filter_offset; + int32 output_multiplier; + int32 output_activation_min; + int32 output_activation_max; + int32 output_shift; + int32 input_width; + int32 input_height; + int32 stride_width; + int32 stride_height; + int32 output_width; + int32 output_height; }; -template <> -struct ConvKernel3x3FilterDepth8<2, 1, 1, 1> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - // To process 2x1 outputs using a 3x3 filter, we require 4x3 inputs. - // Load all inputs at the beginning. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11; - - // Load inputs for 1x2 outputs starting from the top left. - { - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5; - - const uint8* ptr = input_ptr; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_10 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_11 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - input_10 = vaddq_s16(input_10, input_offset_vec); - input_11 = vaddq_s16(input_11, input_offset_vec); - } - - DotProductAndStore2yStride1( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9, input_10, input_11, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth * output_width); - } -}; +#define STR(s) STR_UNEXPANDED(s) +#define STR_UNEXPANDED(s) #s + +// Represents the number of bytes offset from the start of the +// DepthwiseConvParams struct. This is used in the asm to load parameters. +// Keep these values in sync with the static_asserts below. +#define OFFSET_INPUT_DEPTH 0 +#define OFFSET_INPUT_ROW_SIZE 8 +#define OFFSET_OUTPUT_DEPTH 16 +#define OFFSET_OUTPUT_ROW_SIZE 24 +#define OFFSET_FILTER_ROW_SIZE 32 +#define OFFSET_INPUT_OFFSET 40 +#define OFFSET_OUTPUT_OFFSET 44 +#define OFFSET_FILTER_OFFSET 48 +#define OFFSET_OUTPUT_MULTIPLIER 52 +#define OFFSET_OUTPUT_ACTIVATION_MIN 56 +#define OFFSET_OUTPUT_ACTIVATION_MAX 60 +#define OFFSET_OUTPUT_SHIFT 64 +#define OFFSET_INPUT_WIDTH 68 +#define OFFSET_INPUT_HEIGHT 72 +#define OFFSET_STRIDE_WIDTH 76 +#define OFFSET_STRIDE_HEIGHT 80 +#define OFFSET_OUTPUT_WIDTH 84 +#define OFFSET_OUTPUT_HEIGHT 88 + +static_assert(offsetof(DepthwiseConvParams, input_depth) == + OFFSET_INPUT_DEPTH, ""); +static_assert(offsetof(DepthwiseConvParams, input_row_size) == + OFFSET_INPUT_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, output_depth) == + OFFSET_OUTPUT_DEPTH, ""); +static_assert(offsetof(DepthwiseConvParams, output_row_size) == + OFFSET_OUTPUT_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, filter_row_size) == + OFFSET_FILTER_ROW_SIZE, ""); +static_assert(offsetof(DepthwiseConvParams, input_offset) == + OFFSET_INPUT_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, output_offset) == + OFFSET_OUTPUT_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, filter_offset) == + OFFSET_FILTER_OFFSET, ""); +static_assert(offsetof(DepthwiseConvParams, output_multiplier) == + OFFSET_OUTPUT_MULTIPLIER, ""); +static_assert(offsetof(DepthwiseConvParams, output_activation_min) == + OFFSET_OUTPUT_ACTIVATION_MIN, ""); +static_assert(offsetof(DepthwiseConvParams, output_activation_max) == + OFFSET_OUTPUT_ACTIVATION_MAX, ""); +static_assert(offsetof(DepthwiseConvParams, output_shift) == + OFFSET_OUTPUT_SHIFT, ""); +static_assert(offsetof(DepthwiseConvParams, input_width) == + OFFSET_INPUT_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, input_height) == + OFFSET_INPUT_HEIGHT, ""); +static_assert(offsetof(DepthwiseConvParams, stride_width) == + OFFSET_STRIDE_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, stride_height) == + OFFSET_STRIDE_HEIGHT, ""); +static_assert(offsetof(DepthwiseConvParams, output_width) == + OFFSET_OUTPUT_WIDTH, ""); +static_assert(offsetof(DepthwiseConvParams, output_height) == + OFFSET_OUTPUT_HEIGHT, ""); + +template +struct DepthwiseConvWindow {}; template <> -struct ConvKernel3x3FilterDepth8<4, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9; - - const uint8* ptr = input_ptr; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - // Load first 2 rows. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next 2 rows. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); - - output_ptr += output_row_size; - - // Moving onto the next row of outputs. - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_0.high = vld1q_s32(bias_ptr + 4); - acc_1.high = vld1q_s32(bias_ptr + 4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load last row. - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - DownquantizeAndStore2Output( - acc_0, acc_1, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth); +struct DepthwiseConvWindow<8, 1, 1> { + public: + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, + const DepthwiseConvParams* params_ptr) { + const int64_t input_width_increment = 2 * input_depth; + const int64_t input_height_increment = 2 * input_row_size; + const int64_t output_height_increment = 2 * params_ptr->output_row_size; + +#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5" +#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6" +#define DEPTHWISECONV_LABEL_HEIGHT_1 "7" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10" +#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11" + + asm volatile( + // Performs depthwise convolutions for a window specified by + // |output_window_height| and |output_window_width|. The inner-most loop + // processes 2x2 outputs, and any leftovers at the end. + // + // Algorithm works as follows: + // + // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter + // values. + // 2. For 2 output heights at a time: + // i. For 2 output widths at a time, load inputs for a 2x1 (2 + // height, 1 width) output window (4x3 input window). + // Registers v9--v20 hold input values. Mul-add with + // accumulators v21--v24. Then run activation, downquantize + // and store. Repeat for the next 2x1 output window, + // leveraging overlapping inputs. + // ii. Handle single leftover width if exists. + // 3. Handle single leftover height if exists. + // i. For 2 output widths at a time, load inputs for a 1x2 (1 + // height, 2 width) output window (3x4 input window). + // Registers v9--v20 hold input values. Mul-add with + // accumulators v21--v24. Then run activation, downquantize + // and store. Repeat for the next 1x2 output window, + // leveraging overlapping inputs. + // ii. Handle single leftover width if exists. + // + // Loads are placed as soon as the register is no longer needed and + // interleaved with arithmetic operations to take advantage of + // dual-issue pipelines. We also add input offsets as far from the loads + // as possible to give loads enough cycles to fetch data from memory. + + // Set "constant" registers. These registers may be replaced with temp + // values from time to time when there are not enough NEON registers. + // We use x9--x15 general purpose registers as they are caller-saved + // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT + "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr x3, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "cmp %w[output_window_height], #2\n" + "dup v26.8h, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v29.4s, w2\n" + "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "dup v30.4s, w4\n" + "ldr w0, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v31.4s, w0\n" + "neg w9, w9\n" + "dup v28.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "add x10, %[bias_ptr], #16\n" + "ldr x1, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n" + "dup v9.8h, w9\n" + + // Load filters and add offsets. + "ld1 {v0.8b}, [%[filter_ptr]], x3\n" + "ld1 {v1.8b}, [%[filter_ptr]], x3\n" + "uaddw v0.8h, v9.8h, v0.8b\n" + "ld1 {v2.8b}, [%[filter_ptr]], x3\n" + "uaddw v1.8h, v9.8h, v1.8b\n" + "ld1 {v3.8b}, [%[filter_ptr]], x3\n" + "uaddw v2.8h, v9.8h, v2.8b\n" + "ld1 {v4.8b}, [%[filter_ptr]], x3\n" + "uaddw v3.8h, v9.8h, v3.8b\n" + "ld1 {v5.8b}, [%[filter_ptr]], x3\n" + "uaddw v4.8h, v9.8h, v4.8b\n" + "ld1 {v6.8b}, [%[filter_ptr]], x3\n" + "uaddw v5.8h, v9.8h, v5.8b\n" + "ld1 {v7.8b}, [%[filter_ptr]], x3\n" + "uaddw v6.8h, v9.8h, v6.8b\n" + "ld1 {v8.8b}, [%[filter_ptr]], x3\n" + "uaddw v7.8h, v9.8h, v7.8b\n" + "uaddw v8.8h, v9.8h, v8.8b\n" + + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n" + // This loop processes 2x2 outputs. To avoid register exhaustion, + // inputs for the left 2 outputs are loaded first, then the right + // two outputs. + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x13, x11, %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "add x14, x13, %[input_row_size]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x14, %[input_row_size]\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "mov w5, %w[output_window_width]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x1\n" + "ld1 {v15.8b}, [x14], %[input_depth]\n" + // The height 2 / width 2 loop loads an extra 2x1 outputs (2 height, + // 1 width) in anticipation for the next iteration. Make sure + // |output_window_width| is large enough to handle the additional + // loads, otherwise jump to specific the appropriate label to handle + // smaller widths. + "cmp w5, #2\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v16.8b}, [x14], %[input_depth]\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "ld1 {v18.8b}, [x15], %[input_depth]\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "ld1 {v19.8b}, [x15], %[input_depth]\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "ld1 {v20.8b}, [x15], %[input_depth]\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v22.4s}, [x10]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "ld1 {v24.4s}, [x10]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n" + "cmp w5, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n" + // Mul-add left outputs. + "smlal v21.4s, v0.4h, v9.4h\n" + "subs w5, w5, #2\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "cmp w5, #3\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "ld1 {v9.8b}, [x12]\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "ld1 {v12.8b}, [x13]\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "ld1 {v15.8b}, [x14]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "ld1 {v18.8b}, [x15]\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + + // Mul-add right outputs. + "smlal v21.4s, v0.4h, v10.4h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal2 v22.4s, v0.8h, v10.8h\n" + "mov x12, x11\n" + "smlal v23.4s, v0.4h, v13.4h\n" + "add x13, x11, %[input_row_size]\n" + "smlal2 v24.4s, v0.8h, v13.8h\n" + "add x14, x13, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v11.4h\n" + "add x15, x14, %[input_row_size]\n" + "smlal2 v22.4s, v1.8h, v11.8h\n" + "smlal v23.4s, v1.4h, v14.4h\n" + "smlal2 v24.4s, v1.8h, v14.8h\n" + "smlal v21.4s, v2.4h, v9.4h\n" + "smlal2 v22.4s, v2.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "smlal v21.4s, v5.4h, v12.4h\n" + "smlal2 v22.4s, v5.8h, v12.8h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v5.4h, v15.4h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v5.8h, v15.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v6.4h, v16.4h\n" + "smlal2 v22.4s, v6.8h, v16.8h\n" + "smlal v23.4s, v6.4h, v19.4h\n" + "smlal2 v24.4s, v6.8h, v19.8h\n" + "smlal v21.4s, v7.4h, v17.4h\n" + "smlal2 v22.4s, v7.8h, v17.8h\n" + "smlal v23.4s, v7.4h, v20.4h\n" + "smlal2 v24.4s, v7.8h, v20.8h\n" + "smlal v21.4s, v8.4h, v15.4h\n" + "smlal2 v22.4s, v8.8h, v15.8h\n" + "ld1 {v15.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v8.4h, v18.4h\n" + "ld1 {v16.8b}, [x14], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v18.8h\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "ld1 {v18.8b}, [x15], %[input_depth]\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "ld1 {v19.8b}, [x15], %[input_depth]\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "ld1 {v20.8b}, [x15], %[input_depth]\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w5, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + // Handle last 2 columns if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n" + // Mul-add left outputs. + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "ld1 {v9.8b}, [x12]\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "ld1 {v12.8b}, [x13]\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "ld1 {v15.8b}, [x14]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "ld1 {v18.8b}, [x15]\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x3\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "st1 {v23.8b}, [x7], x3\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + + // Mul-add right outputs. + "smlal v21.4s, v0.4h, v10.4h\n" + "smlal2 v22.4s, v0.8h, v10.8h\n" + "smlal v23.4s, v0.4h, v13.4h\n" + "smlal2 v24.4s, v0.8h, v13.8h\n" + "smlal v21.4s, v1.4h, v11.4h\n" + "smlal2 v22.4s, v1.8h, v11.8h\n" + "smlal v23.4s, v1.4h, v14.4h\n" + "smlal2 v24.4s, v1.8h, v14.8h\n" + "smlal v21.4s, v2.4h, v9.4h\n" + "smlal2 v22.4s, v2.8h, v9.8h\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "smlal v21.4s, v5.4h, v12.4h\n" + "smlal2 v22.4s, v5.8h, v12.8h\n" + "smlal v23.4s, v5.4h, v15.4h\n" + "smlal2 v24.4s, v5.8h, v15.8h\n" + "smlal v21.4s, v6.4h, v16.4h\n" + "smlal2 v22.4s, v6.8h, v16.8h\n" + "smlal v23.4s, v6.4h, v19.4h\n" + "smlal2 v24.4s, v6.8h, v19.8h\n" + "smlal v21.4s, v7.4h, v17.4h\n" + "smlal2 v22.4s, v7.8h, v17.8h\n" + "smlal v23.4s, v7.4h, v20.4h\n" + "smlal2 v24.4s, v7.8h, v20.8h\n" + "smlal v21.4s, v8.4h, v15.4h\n" + "smlal2 v22.4s, v8.8h, v15.8h\n" + "smlal v23.4s, v8.4h, v18.4h\n" + "smlal2 v24.4s, v8.8h, v18.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [x6], x3\n" + "st1 {v23.8b}, [x7], x3\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v14.4h\n" + "smlal2 v24.4s, v2.8h, v14.8h\n" + "smlal v21.4s, v3.4h, v12.4h\n" + "smlal2 v22.4s, v3.8h, v12.8h\n" + "smlal v23.4s, v3.4h, v15.4h\n" + "smlal2 v24.4s, v3.8h, v15.8h\n" + "smlal v21.4s, v4.4h, v13.4h\n" + "smlal2 v22.4s, v4.8h, v13.8h\n" + "smlal v23.4s, v4.4h, v16.4h\n" + "smlal2 v24.4s, v4.8h, v16.8h\n" + "smlal v21.4s, v5.4h, v14.4h\n" + "smlal2 v22.4s, v5.8h, v14.8h\n" + "smlal v23.4s, v5.4h, v17.4h\n" + "smlal2 v24.4s, v5.8h, v17.8h\n" + "smlal v21.4s, v6.4h, v15.4h\n" + "smlal2 v22.4s, v6.8h, v15.8h\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v16.4h\n" + "smlal2 v22.4s, v7.8h, v16.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v9.16b, v21.16b, v28.16b\n" + "and v12.16b, v22.16b, v28.16b\n" + "and v15.16b, v23.16b, v28.16b\n" + "and v18.16b, v24.16b, v28.16b\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v12.4s, v12.4s, #31\n" + "sshr v15.4s, v15.4s, #31\n" + "sshr v18.4s, v18.4s, #31\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v12.4s\n" + "sqadd v23.4s, v23.4s, v15.4s\n" + "sqadd v24.4s, v24.4s, v18.4s\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [x6], x3\n" + "st1 {v23.8b}, [x7], x3\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n" + "subs %w[output_window_height], %w[output_window_height], #2\n" + "add %[input_ptr], %[input_ptr], %[input_height_increment]\n" + "cmp %w[output_window_height], #2\n" + "add %[output_ptr], %[output_ptr], %[output_height_increment]\n" + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n" + "cmp %w[output_window_height], #1\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_1 ":\n" + "mov x12, %[input_ptr]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x13, %[input_ptr], %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "add x14, x13, %[input_row_size]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x14, %[input_row_size]\n" + "mov w5, %w[output_window_width]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x1\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + // The height 1 / width 2 loop loads an extra 1x1 output in anticipation + // for the next iteration. Make sure |output_window_width| is large + // enough to handle the additional load, otherwise jump to the + // appropriate label to handle smaller widths. + "cmp w5, #2\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "ld1 {v18.8b}, [x14], %[input_depth]\n" + "ld1 {v19.8b}, [x14], %[input_depth]\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "ld1 {v22.4s}, [x10]\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "ld1 {v24.4s}, [x10]\n" + + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n" + "cmp w5, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n" + // Load inputs for 3x4 input window which corresponds to a 1x2 output + // window. + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v16.8b}, [x13]\n" + "smlal v23.4s, v0.4h, v10.4h\n" + "ld1 {v20.8b}, [x14]\n" + "smlal2 v24.4s, v0.8h, v10.8h\n" + "subs w5, w5, #2\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "cmp w5, #3\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "add %[input_ptr], %[input_ptr], %[input_width_increment]\n" + "smlal v23.4s, v1.4h, v11.4h\n" + "mov x12, %[input_ptr]\n" + "smlal2 v24.4s, v1.8h, v11.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x13, %[input_ptr], %[input_row_size]\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "add x14, x13, %[input_row_size]\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "add x15, x14, %[input_row_size]\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v3.4h, v14.4h\n" + "smlal2 v24.4s, v3.8h, v14.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v4.4h, v15.4h\n" + "smlal2 v24.4s, v4.8h, v15.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v5.4h, v16.4h\n" + "smlal2 v24.4s, v5.8h, v16.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "ld1 {v17.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "ld1 {v18.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + "ld1 {v19.8b}, [x14], %[input_depth]\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "st1 {v21.8b}, [%[output_ptr]], x3\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "st1 {v23.8b}, [%[output_ptr]], x3\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + "uaddw v14.8h, v26.8h, v14.8b\n" + "uaddw v15.8h, v26.8h, v15.8b\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v17.8h, v26.8h, v17.8b\n" + "uaddw v18.8h, v26.8h, v18.8b\n" + "uaddw v19.8h, v26.8h, v19.8b\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w5, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + // Handle last two horizontal outputs if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v0.4h, v10.4h\n" + "ld1 {v20.8b}, [x14], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v10.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v23.4s, v1.4h, v11.4h\n" + "smlal2 v24.4s, v1.8h, v11.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v23.4s, v2.4h, v12.4h\n" + "smlal2 v24.4s, v2.8h, v12.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v23.4s, v3.4h, v14.4h\n" + "smlal2 v24.4s, v3.8h, v14.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v23.4s, v4.4h, v15.4h\n" + "smlal2 v24.4s, v4.8h, v15.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "uaddw v16.8h, v26.8h, v16.8b\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "smlal v23.4s, v5.4h, v16.4h\n" + "smlal2 v24.4s, v5.8h, v16.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "smlal v23.4s, v6.4h, v18.4h\n" + "smlal2 v24.4s, v6.8h, v18.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "smlal v23.4s, v7.4h, v19.4h\n" + "smlal2 v24.4s, v7.8h, v19.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "uaddw v20.8h, v26.8h, v20.8b\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + "smlal v23.4s, v8.4h, v20.4h\n" + "smlal2 v24.4s, v8.8h, v20.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v25.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v25.4s, v25.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v25.4s\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w4\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w0\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v21.8b}, [%[output_ptr]], x3\n" + "st1 {v23.8b}, [%[output_ptr]], x3\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + // Handle bottom right output if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "smlal v21.4s, v3.4h, v13.4h\n" + "smlal2 v22.4s, v3.8h, v13.8h\n" + "smlal v21.4s, v4.4h, v14.4h\n" + "smlal2 v22.4s, v4.8h, v14.8h\n" + "smlal v21.4s, v5.4h, v15.4h\n" + "smlal2 v22.4s, v5.8h, v15.8h\n" + "smlal v21.4s, v6.4h, v17.4h\n" + "smlal2 v22.4s, v6.8h, v17.8h\n" + "smlal v21.4s, v7.4h, v18.4h\n" + "smlal2 v22.4s, v7.8h, v18.8h\n" + "smlal v21.4s, v8.4h, v19.4h\n" + "smlal2 v22.4s, v8.8h, v19.8h\n" + + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "and v9.16b, v21.16b, v28.16b\n" + "and v12.16b, v22.16b, v28.16b\n" + "sshr v9.4s, v9.4s, #31\n" + "sshr v12.4s, v12.4s, #31\n" + "sqadd v21.4s, v21.4s, v9.4s\n" + "sqadd v22.4s, v22.4s, v12.4s\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtun v21.8b, v21.8h\n" + "st1 {v21.8b}, [%[output_ptr]]\n" + DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), + [output_window_height] "+r"(output_window_height) + : + // Inputs. + [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size), + [input_depth] "r"(input_depth), + [output_window_width] "r"(output_window_width), + [input_width_increment] "r"(input_width_increment), + [input_height_increment] "r"(input_height_increment), + [output_height_increment] "r"(output_height_increment), + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1 +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_END } }; template <> -struct ConvKernel3x3FilterDepth8<4, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - // Reuse 4x2 kernel twice. - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth, - output_width); - - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr + 2 * output_depth, output_depth, output_width); +struct DepthwiseConvWindow<8, 2, 2> { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, + const DepthwiseConvParams* params_ptr) { + const int64_t input_width_increment = 4 * input_depth; + const int64_t input_height_increment = 4 * input_row_size; + const int64_t output_height_increment = 2 * params_ptr->output_row_size; + +#define DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "1" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "2" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "3" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "4" +#define DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "5" +#define DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "6" +#define DEPTHWISECONV_LABEL_HEIGHT_1 "7" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "8" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "9" +#define DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "10" +#define DEPTHWISECONV_LABEL_HEIGHT_1_END "11" + + asm volatile( + // Performs depthwise convolutions for a window specified by + // |output_window_height| and |output_window_width|. The inner-most loop + // processes 2x2 outputs, and any leftovers at the end. + // + // Algorithm works as follows: + // + // 1. Load filters of 8 depth (8x3x3). Registers v0--v8 hold filter + // values. + // 2. For 2 output heights at a time: + // i. For 2 output widths at a time at stride 2, a 5x5 input + // window is required. To avoid register exhaustion, we load + // the first 2 rows of the 5x5 input window into registers + // v9--v18, and use the same registers to load the next 2 + // rows, and finally v9--v13 to load the last row. + // Accumulators for all 2x2 outputs are reserved by registers + // v21-v22 (top left output), v23-v24 (top right output), + // v19-v20 (bottom left output), v25-v26 (bottom right + // output). + // ii. Handle single leftover width if exists. + // 3. Handle single leftover height if exists. + // i. For 2 output widths at a time at stride 2, load inputs for + // a 1x2 (1 height, 2 width) output window (3x5 input + // window). Registers v9--v24 hold input values. Mul-add with + // accumulators v24--v27. + // ii. Handle single leftover width if exists. + // + // Loads are placed as soon as the register is no longer needed and + // interleaved with arithmetic operations to take advantage of + // dual-issue pipelines. We also add input offsets as far from the loads + // as possible to give loads enough cycles to fetch data from memory. + + // Set "constant" registers. These registers may be replaced with temp + // values from time to time when there are not enough NEON registers. + // We use x9--x15 general purpose registers as they are caller-saved + // temporary registers (see http://infocenter.arm.com/help/topic/com.arm.doc.ihi0055b/IHI0055B_aapcs64.pdf). // NOLINT + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "ldr w0, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "cmp %w[output_window_height], #2\n" + "dup v28.8h, w0\n" + "neg w9, w9\n" + "ldr w1, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.4s, w9\n" + "ldr w2, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w1\n" + "ldr w3, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "dup v29.4s, w2\n" + "ldr w4, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w3\n" + "ldr x5, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "dup v31.4s, w4\n" + "ldr x19, [%[params_ptr], #" STR(OFFSET_OUTPUT_ROW_SIZE) "]\n" + "ldr w20, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + + // Load filters and add offsets. + "add x10, %[bias_ptr], #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], x5\n" + "dup v9.8h, w20\n" + "ld1 {v1.8b}, [%[filter_ptr]], x5\n" + "uaddw v0.8h, v9.8h, v0.8b\n" + "ld1 {v2.8b}, [%[filter_ptr]], x5\n" + "uaddw v1.8h, v9.8h, v1.8b\n" + "ld1 {v3.8b}, [%[filter_ptr]], x5\n" + "uaddw v2.8h, v9.8h, v2.8b\n" + "ld1 {v4.8b}, [%[filter_ptr]], x5\n" + "uaddw v3.8h, v9.8h, v3.8b\n" + "ld1 {v5.8b}, [%[filter_ptr]], x5\n" + "uaddw v4.8h, v9.8h, v4.8b\n" + "ld1 {v6.8b}, [%[filter_ptr]], x5\n" + "uaddw v5.8h, v9.8h, v5.8b\n" + "ld1 {v7.8b}, [%[filter_ptr]], x5\n" + "uaddw v6.8h, v9.8h, v6.8b\n" + "ld1 {v8.8b}, [%[filter_ptr]]\n" + "uaddw v7.8h, v9.8h, v7.8b\n" + "uaddw v8.8h, v9.8h, v8.8b\n" + + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_LOOP ":\n" + // Load the first two rows of the 5x5 input window, then reuse the + // same registers to load subsequent rows as they become available. + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "add x13, x12, %[input_row_size]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "mov w14, %w[output_window_width]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + // The height 2 / width 2 loop loads an extra 1 output horizontally in + // anticipation for the next iteration. Make sure + // |output_window_width| is large enough to handle the additional + // load, otherwise jump to the appropriate label to handle smaller + // widths. + "cmp w14, #2\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "add x15, x13, %[input_row_size]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "add x7, %[output_ptr], x19\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "ld1 {v22.4s}, [x10]\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "ld1 {v24.4s}, [x10]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "ld1 {v19.4s}, [%[bias_ptr]]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v20.4s}, [x10]\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v25.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "ld1 {v26.4s}, [x10]\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER "f\n" + "cmp w14, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v13.8b}, [x12]\n" + "add x12, x15, %[input_row_size]\n" + "smlal v23.4s, v0.4h, v11.4h\n" + "ld1 {v17.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v11.8h\n" + "ld1 {v18.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "subs w14, w14, #2\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "cmp w14, #3\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v1.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v1.8h, v12.8h\n" + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v23.4s, v2.4h, v13.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v24.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x15]\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "ld1 {v17.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v5.4h, v18.4h\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "smlal2 v24.4s, v5.8h, v18.8h\n" + "ld1 {v18.8b}, [x12]\n" + + "smlal v21.4s, v6.4h, v9.4h\n" + "smlal2 v22.4s, v6.8h, v9.8h\n" + "smlal v19.4s, v0.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v20.4s, v0.8h, v9.8h\n" + "ld1 {v9.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v6.4h, v11.4h\n" + "smlal2 v24.4s, v6.8h, v11.8h\n" + "smlal v21.4s, v7.4h, v10.4h\n" + "smlal2 v22.4s, v7.8h, v10.8h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal v19.4s, v1.4h, v10.4h\n" + "smlal2 v20.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v7.4h, v12.4h\n" + "smlal2 v24.4s, v7.8h, v12.8h\n" + "smlal v25.4s, v1.4h, v12.4h\n" + "smlal2 v26.4s, v1.8h, v12.8h\n" + "smlal v21.4s, v8.4h, v11.4h\n" + "smlal2 v22.4s, v8.8h, v11.8h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal v19.4s, v2.4h, v11.4h\n" + "mov x12, x11\n" + "smlal2 v20.4s, v2.8h, v11.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal v25.4s, v0.4h, v11.4h\n" + "smlal2 v26.4s, v0.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v8.4h, v13.4h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v13.8h\n" + "smlal v25.4s, v2.4h, v13.4h\n" + "smlal2 v26.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "add x15, x13, %[input_row_size]\n" + + "dup v28.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v27.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v23.8b}, [x6], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + + "smlal v19.4s, v6.4h, v9.4h\n" + "smlal2 v20.4s, v6.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v6.4h, v11.4h\n" + "smlal2 v26.4s, v6.8h, v11.8h\n" + "smlal v19.4s, v7.4h, v10.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v20.4s, v7.8h, v10.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v7.4h, v12.4h\n" + "smlal2 v26.4s, v7.8h, v12.8h\n" + "smlal v19.4s, v8.4h, v11.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v20.4s, v8.8h, v11.8h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal v25.4s, v8.4h, v13.4h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal2 v26.4s, v8.8h, v13.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v19.4s, v3.4h, v14.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v20.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v3.4h, v16.4h\n" + "ld1 {v21.4s}, [%[bias_ptr]]\n" + "smlal2 v26.4s, v3.8h, v16.8h\n" + "ld1 {v23.4s}, [%[bias_ptr]]\n" + "smlal v19.4s, v4.4h, v15.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v20.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v4.4h, v17.4h\n" + "smlal2 v26.4s, v4.8h, v17.8h\n" + "smlal v19.4s, v5.4h, v16.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v20.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x13], %[input_depth]\n" + "smlal v25.4s, v5.4h, v18.4h\n" + "smlal2 v26.4s, v5.8h, v18.8h\n" + + "dup v28.4s, w9\n" + "sqrdmulh v19.4s, v19.4s, v27.4s\n" + "sqrdmulh v20.4s, v20.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "sqrdmulh v26.4s, v26.4s, v27.4s\n" + "and v27.16b, v19.16b, v28.16b\n" + "and v29.16b, v20.16b, v28.16b\n" + "and v30.16b, v25.16b, v28.16b\n" + "and v31.16b, v26.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v19.4s, v19.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v20.4s, v20.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v25.4s, v25.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v26.4s, v26.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v19.4s, v19.4s, v28.4s\n" + "srshl v20.4s, v20.4s, v28.4s\n" + "srshl v25.4s, v25.4s, v28.4s\n" + "srshl v26.4s, v26.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v19.4s, v19.4s, v29.4s\n" + "add v20.4s, v20.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "add v26.4s, v26.4s, v29.4s\n" + "smax v19.4s, v19.4s, v30.4s\n" + "smax v20.4s, v20.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smin v19.4s, v19.4s, v31.4s\n" + "smin v20.4s, v20.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "sqxtn v19.4h, v19.4s\n" + "sqxtn v25.4h, v25.4s\n" + "sqxtn2 v19.8h, v20.4s\n" + "ld1 {v20.4s}, [x10]\n" + "sqxtn2 v25.8h, v26.4s\n" + "ld1 {v26.4s}, [x10]\n" + "sqxtun v19.8b, v19.8h\n" + "sqxtun v25.8b, v25.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v19.8b}, [x7], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v25.8b}, [x7], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v19.4s}, [%[bias_ptr]]\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v25.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w14, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER "f\n" + + // Handle last 2 columns if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER ":\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v12.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v13.8b}, [x12]\n" + "add x12, x15, %[input_row_size]\n" + "smlal v23.4s, v0.4h, v11.4h\n" + "ld1 {v17.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v0.8h, v11.8h\n" + "ld1 {v18.8b}, [x13]\n" + "add x13, x12, %[input_row_size]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v3.4h, v16.4h\n" + "smlal2 v24.4s, v3.8h, v16.8h\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "ld1 {v16.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v1.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v1.8h, v12.8h\n" + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v23.4s, v2.4h, v13.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v24.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x15]\n" + "smlal v23.4s, v4.4h, v17.4h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "smlal2 v24.4s, v4.8h, v17.8h\n" + "ld1 {v17.8b}, [x12], %[input_depth]\n" + "smlal v23.4s, v5.4h, v18.4h\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "smlal2 v24.4s, v5.8h, v18.8h\n" + "ld1 {v18.8b}, [x12]\n" + + "smlal v21.4s, v6.4h, v9.4h\n" + "smlal2 v22.4s, v6.8h, v9.8h\n" + "smlal v19.4s, v0.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v20.4s, v0.8h, v9.8h\n" + "ld1 {v9.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v6.4h, v11.4h\n" + "smlal2 v24.4s, v6.8h, v11.8h\n" + "smlal v21.4s, v7.4h, v10.4h\n" + "smlal2 v22.4s, v7.8h, v10.8h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal v19.4s, v1.4h, v10.4h\n" + "smlal2 v20.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v7.4h, v12.4h\n" + "smlal2 v24.4s, v7.8h, v12.8h\n" + "smlal v25.4s, v1.4h, v12.4h\n" + "smlal2 v26.4s, v1.8h, v12.8h\n" + "smlal v21.4s, v8.4h, v11.4h\n" + "smlal2 v22.4s, v8.8h, v11.8h\n" + "smlal v19.4s, v2.4h, v11.4h\n" + "smlal2 v20.4s, v2.8h, v11.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal v25.4s, v0.4h, v11.4h\n" + "smlal2 v26.4s, v0.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], %[input_depth]\n" + "smlal v23.4s, v8.4h, v13.4h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal2 v24.4s, v8.8h, v13.8h\n" + "smlal v25.4s, v2.4h, v13.4h\n" + "smlal2 v26.4s, v2.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + + "dup v28.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v27.16b, v21.16b, v28.16b\n" + "and v29.16b, v22.16b, v28.16b\n" + "and v30.16b, v23.16b, v28.16b\n" + "and v31.16b, v24.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v21.4s, v21.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v22.4s, v22.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v23.4s, v23.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v24.4s, v24.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v21.4s, v21.4s, v28.4s\n" + "srshl v22.4s, v22.4s, v28.4s\n" + "srshl v23.4s, v23.4s, v28.4s\n" + "srshl v24.4s, v24.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "ld1 {v22.4s}, [x10]\n" + "sqxtn2 v23.8h, v24.4s\n" + "ld1 {v24.4s}, [x10]\n" + "sqxtun v21.8b, v21.8h\n" + "sqxtun v23.8b, v23.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v23.8b}, [x6]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + + "smlal v19.4s, v6.4h, v9.4h\n" + "smlal2 v20.4s, v6.8h, v9.8h\n" + "smlal v25.4s, v6.4h, v11.4h\n" + "smlal2 v26.4s, v6.8h, v11.8h\n" + "smlal v19.4s, v7.4h, v10.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v20.4s, v7.8h, v10.8h\n" + "smlal v25.4s, v7.4h, v12.4h\n" + "smlal2 v26.4s, v7.8h, v12.8h\n" + "smlal v19.4s, v8.4h, v11.4h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "smlal2 v20.4s, v8.8h, v11.8h\n" + "smlal v25.4s, v8.4h, v13.4h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal2 v26.4s, v8.8h, v13.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v19.4s, v3.4h, v14.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v20.4s, v3.8h, v14.8h\n" + "smlal v25.4s, v3.4h, v16.4h\n" + "smlal2 v26.4s, v3.8h, v16.8h\n" + "smlal v19.4s, v4.4h, v15.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v20.4s, v4.8h, v15.8h\n" + "smlal v25.4s, v4.4h, v17.4h\n" + "smlal2 v26.4s, v4.8h, v17.8h\n" + "smlal v19.4s, v5.4h, v16.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v20.4s, v5.8h, v16.8h\n" + "smlal v25.4s, v5.4h, v18.4h\n" + "smlal2 v26.4s, v5.8h, v18.8h\n" + + "dup v28.4s, w9\n" + "sqrdmulh v19.4s, v19.4s, v27.4s\n" + "sqrdmulh v20.4s, v20.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "sqrdmulh v26.4s, v26.4s, v27.4s\n" + "and v27.16b, v19.16b, v28.16b\n" + "and v29.16b, v20.16b, v28.16b\n" + "and v30.16b, v25.16b, v28.16b\n" + "and v31.16b, v26.16b, v28.16b\n" + "sshr v27.4s, v27.4s, #31\n" + "sshr v29.4s, v29.4s, #31\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v19.4s, v19.4s, v27.4s\n" + "dup v27.4s, w1\n" + "sqadd v20.4s, v20.4s, v29.4s\n" + "dup v29.4s, w2\n" + "sqadd v25.4s, v25.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v26.4s, v26.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v19.4s, v19.4s, v28.4s\n" + "srshl v20.4s, v20.4s, v28.4s\n" + "srshl v25.4s, v25.4s, v28.4s\n" + "srshl v26.4s, v26.4s, v28.4s\n" + "dup v28.8h, w0\n" + "add v19.4s, v19.4s, v29.4s\n" + "add v20.4s, v20.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "add v26.4s, v26.4s, v29.4s\n" + "smax v19.4s, v19.4s, v30.4s\n" + "smax v20.4s, v20.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smin v19.4s, v19.4s, v31.4s\n" + "smin v20.4s, v20.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "sqxtn v19.4h, v19.4s\n" + "sqxtn v25.4h, v25.4s\n" + "sqxtn2 v19.8h, v20.4s\n" + "sqxtn2 v25.8h, v26.4s\n" + "sqxtun v19.8b, v19.8h\n" + "sqxtun v25.8b, v25.8h\n" + "st1 {v19.8b}, [x7], x5\n" + "st1 {v25.8b}, [x7]\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP "f\n" + + // Handle last column if exists. + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER ":\n" + // Registers v9, v10, v11, v14, v15, and v16 have already been loaded + // with the correct values at this point. This corresponds to the + // first two input rows of the top left output. Now load the last + // input row for this output. Once these inputs are no longer needed, + // load the input rows for the bottom left output. + "add x12, x15, %[input_row_size]\n" + "add x13, x12, %[input_row_size]\n" + + "ld1 {v12.8b}, [x15], %[input_depth]\n" + "smlal v21.4s, v0.4h, v9.4h\n" + "ld1 {v13.8b}, [x15], %[input_depth]\n" + "smlal2 v22.4s, v0.8h, v9.8h\n" + "ld1 {v17.8b}, [x15]\n" + "smlal v21.4s, v1.4h, v10.4h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal2 v22.4s, v1.8h, v10.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v21.4s, v2.4h, v11.4h\n" + "smlal2 v22.4s, v2.8h, v11.8h\n" + "ld1 {v11.8b}, [x12]\n" + "smlal v21.4s, v3.4h, v14.4h\n" + "smlal2 v22.4s, v3.8h, v14.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v4.4h, v15.4h\n" + "smlal2 v22.4s, v4.8h, v15.8h\n" + "ld1 {v15.8b}, [x13], %[input_depth]\n" + "smlal v21.4s, v5.4h, v16.4h\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "smlal2 v22.4s, v5.8h, v16.8h\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "ld1 {v16.8b}, [x13]\n" + + "smlal v21.4s, v6.4h, v12.4h\n" + "smlal2 v22.4s, v6.8h, v12.8h\n" + "smlal v23.4s, v0.4h, v12.4h\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + "smlal2 v24.4s, v0.8h, v12.8h\n" + "smlal v21.4s, v7.4h, v13.4h\n" + "smlal2 v22.4s, v7.8h, v13.8h\n" + "smlal v23.4s, v1.4h, v13.4h\n" + "smlal2 v24.4s, v1.8h, v13.8h\n" + "smlal v21.4s, v8.4h, v17.4h\n" + "smlal2 v22.4s, v8.8h, v17.8h\n" + "smlal v23.4s, v2.4h, v17.4h\n" + "smlal2 v24.4s, v2.8h, v17.8h\n" + + "dup v26.4s, w9\n" + "sqrdmulh v21.4s, v21.4s, v27.4s\n" + "sqrdmulh v22.4s, v22.4s, v27.4s\n" + "and v18.16b, v21.16b, v26.16b\n" + "and v19.16b, v22.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v21.4s, v21.4s, v18.4s\n" + "sqadd v22.4s, v22.4s, v19.4s\n" + "srshl v21.4s, v21.4s, v26.4s\n" + "srshl v22.4s, v22.4s, v26.4s\n" + "add v21.4s, v21.4s, v29.4s\n" + "add v22.4s, v22.4s, v29.4s\n" + "smax v21.4s, v21.4s, v30.4s\n" + "smax v22.4s, v22.4s, v30.4s\n" + "smin v21.4s, v21.4s, v31.4s\n" + "smin v22.4s, v22.4s, v31.4s\n" + "sqxtn v21.4h, v21.4s\n" + "sqxtn2 v21.8h, v22.4s\n" + "sqxtun v21.8b, v21.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v21.8b}, [x6]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + + "smlal v23.4s, v3.4h, v9.4h\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "smlal2 v24.4s, v3.8h, v9.8h\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "smlal v23.4s, v4.4h, v10.4h\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "smlal2 v24.4s, v4.8h, v10.8h\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "smlal v23.4s, v5.4h, v11.4h\n" + "smlal2 v24.4s, v5.8h, v11.8h\n" + + "smlal v23.4s, v6.4h, v14.4h\n" + "smlal2 v24.4s, v6.8h, v14.8h\n" + "smlal v23.4s, v7.4h, v15.4h\n" + "smlal2 v24.4s, v7.8h, v15.8h\n" + "smlal v23.4s, v8.4h, v16.4h\n" + "smlal2 v24.4s, v8.8h, v16.8h\n" + + "sqrdmulh v23.4s, v23.4s, v27.4s\n" + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "and v18.16b, v23.16b, v26.16b\n" + "and v19.16b, v24.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v23.4s, v23.4s, v18.4s\n" + "sqadd v24.4s, v24.4s, v19.4s\n" + "srshl v23.4s, v23.4s, v26.4s\n" + "srshl v24.4s, v24.4s, v26.4s\n" + "add v23.4s, v23.4s, v29.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "smax v23.4s, v23.4s, v30.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smin v23.4s, v23.4s, v31.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "sqxtn v23.4h, v23.4s\n" + "sqxtn2 v23.8h, v24.4s\n" + "sqxtun v23.8b, v23.8h\n" + "st1 {v23.8b}, [x7]\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP ":\n" + "subs %w[output_window_height], %w[output_window_height], #2\n" + "add %[input_ptr], %[input_ptr], %[input_height_increment]\n" + "cmp %w[output_window_height], #2\n" + "add %[output_ptr], %[output_ptr], %[output_height_increment]\n" + "bge " DEPTHWISECONV_LABEL_HEIGHT_2_LOOP "b\n" + + DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP ":\n" + "cmp %w[output_window_height], #1\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + DEPTHWISECONV_LABEL_HEIGHT_1 ":\n" + "mov x11, %[input_ptr]\n" + "mov x12, x11\n" + "add x13, x12, %[input_row_size]\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "add x15, x13, %[input_row_size]\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "mov x6, %[output_ptr]\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "mov w14, %w[output_window_width]\n" + // The height 1 / width 2 loop loads an extra 1x1 output in anticipation + // for the next iteration. Make sure |output_window_width| is large + // enough to handle the additional load, otherwise jump to the + // appropriate label to handle smaller widths. + "cmp w14, #2\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "ld1 {v15.8b}, [x15], %[input_depth]\n" + "ld1 {v16.8b}, [x15], %[input_depth]\n" + "ld1 {v17.8b}, [x15], %[input_depth]\n" + + "uaddw v9.8h, v28.8h, v9.8b\n" + "ld1 {v24.4s}, [%[bias_ptr]]\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "ld1 {v25.4s}, [x10]\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "ld1 {v26.4s}, [%[bias_ptr]]\n" + "ld1 {v27.4s}, [x10]\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER "f\n" + "cmp w14, #1\n" + "beq " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP ":\n" + "smlal v24.4s, v0.4h, v9.4h\n" + "ld1 {v18.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "ld1 {v19.8b}, [x12]\n" + "smlal v26.4s, v0.4h, v11.4h\n" + "ld1 {v20.8b}, [x13], %[input_depth]\n" + "smlal2 v27.4s, v0.8h, v11.8h\n" + "ld1 {v21.8b}, [x13]\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "ld1 {v22.8b}, [x15], %[input_depth]\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "ld1 {v23.8b}, [x15]\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "subs w14, w14, #2\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "cmp w14, #3\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "add x11, x11, %[input_width_increment]\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "mov x12, x11\n" + "smlal v26.4s, v3.4h, v14.4h\n" + "add x13, x12, %[input_row_size]\n" + "smlal2 v27.4s, v3.8h, v14.8h\n" + "add x15, x13, %[input_row_size]\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "ld1 {v9.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "ld1 {v10.8b}, [x12], %[input_depth]\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "ld1 {v11.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "ld1 {v12.8b}, [x13], %[input_depth]\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "ld1 {v13.8b}, [x13], %[input_depth]\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "ld1 {v14.8b}, [x13], %[input_depth]\n" + "smlal v26.4s, v6.4h, v17.4h\n" + "ld1 {v15.8b}, [x15], %[input_depth]\n" + "smlal2 v27.4s, v6.8h, v17.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "ld1 {v16.8b}, [x15], %[input_depth]\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + "ld1 {v17.8b}, [x15], %[input_depth]\n" + "uaddw v19.8h, v28.8h, v19.8b\n" + + "smlal v26.4s, v1.4h, v18.4h\n" + "uaddw v20.8h, v28.8h, v20.8b\n" + "smlal2 v27.4s, v1.8h, v18.8h\n" + "smlal v26.4s, v2.4h, v19.4h\n" + "uaddw v21.8h, v28.8h, v21.8b\n" + "smlal2 v27.4s, v2.8h, v19.8h\n" + "smlal v26.4s, v4.4h, v20.4h\n" + "smlal v26.4s, v5.4h, v21.4h\n" + "smlal2 v27.4s, v4.8h, v20.8h\n" + "uaddw v22.8h, v28.8h, v22.8b\n" + "smlal2 v27.4s, v5.8h, v21.8h\n" + "uaddw v23.8h, v28.8h, v23.8b\n" + "smlal v26.4s, v7.4h, v22.4h\n" + "smlal2 v27.4s, v7.8h, v22.8h\n" + "smlal v26.4s, v8.4h, v23.4h\n" + "smlal2 v27.4s, v8.8h, v23.8h\n" + + "dup v28.4s, w1\n" + "dup v29.4s, w9\n" + "sqrdmulh v24.4s, v24.4s, v28.4s\n" + "sqrdmulh v25.4s, v25.4s, v28.4s\n" + "sqrdmulh v26.4s, v26.4s, v28.4s\n" + "sqrdmulh v27.4s, v27.4s, v28.4s\n" + "dup v28.4s, w2\n" + "and v30.16b, v24.16b, v29.16b\n" + "and v31.16b, v25.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v24.4s, v24.4s, v30.4s\n" + "sqadd v25.4s, v25.4s, v31.4s\n" + "and v30.16b, v26.16b, v29.16b\n" + "and v31.16b, v27.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v26.4s, v26.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v27.4s, v27.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v24.4s, v24.4s, v29.4s\n" + "srshl v25.4s, v25.4s, v29.4s\n" + "srshl v26.4s, v26.4s, v29.4s\n" + "srshl v27.4s, v27.4s, v29.4s\n" + "add v24.4s, v24.4s, v28.4s\n" + "add v25.4s, v25.4s, v28.4s\n" + "add v26.4s, v26.4s, v28.4s\n" + "add v27.4s, v27.4s, v28.4s\n" + "dup v28.8h, w0\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smax v27.4s, v27.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "smin v27.4s, v27.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn v26.4h, v26.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "ld1 {v25.4s}, [x10]\n" + "sqxtn2 v26.8h, v27.4s\n" + "ld1 {v27.4s}, [x10]\n" + "sqxtun v24.8b, v24.8h\n" + "sqxtun v26.8b, v26.8h\n" + "uaddw v9.8h, v28.8h, v9.8b\n" + "st1 {v24.8b}, [x6], x5\n" + "uaddw v10.8h, v28.8h, v10.8b\n" + "st1 {v26.8b}, [x6], x5\n" + "uaddw v11.8h, v28.8h, v11.8b\n" + "uaddw v12.8h, v28.8h, v12.8b\n" + "uaddw v13.8h, v28.8h, v13.8b\n" + "uaddw v14.8h, v28.8h, v14.8b\n" + "ld1 {v24.4s}, [%[bias_ptr]]\n" + "uaddw v15.8h, v28.8h, v15.8b\n" + "ld1 {v26.4s}, [%[bias_ptr]]\n" + "uaddw v16.8h, v28.8h, v16.8b\n" + "uaddw v17.8h, v28.8h, v17.8b\n" + + "bge " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP "b\n" + + // At this point, there will be one of 2 width or 1 width leftover, + // not both. + "cmp w14, #2\n" + "blt " DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER "f\n" + + // Handle last two horizontal outputs if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER ":\n" + "smlal v24.4s, v0.4h, v9.4h\n" + "ld1 {v18.8b}, [x12], %[input_depth]\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "ld1 {v19.8b}, [x12]\n" + "smlal v26.4s, v0.4h, v11.4h\n" + "ld1 {v20.8b}, [x13], %[input_depth]\n" + "smlal2 v27.4s, v0.8h, v11.8h\n" + "ld1 {v21.8b}, [x13]\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "ld1 {v22.8b}, [x15], %[input_depth]\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "ld1 {v23.8b}, [x15]\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "smlal v26.4s, v3.4h, v14.4h\n" + "smlal2 v27.4s, v3.8h, v14.8h\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "smlal v26.4s, v6.4h, v17.4h\n" + "smlal2 v27.4s, v6.8h, v17.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "uaddw v18.8h, v28.8h, v18.8b\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + "uaddw v19.8h, v28.8h, v19.8b\n" + + "smlal v26.4s, v1.4h, v18.4h\n" + "uaddw v20.8h, v28.8h, v20.8b\n" + "smlal2 v27.4s, v1.8h, v18.8h\n" + "smlal v26.4s, v2.4h, v19.4h\n" + "uaddw v21.8h, v28.8h, v21.8b\n" + "smlal2 v27.4s, v2.8h, v19.8h\n" + "smlal v26.4s, v4.4h, v20.4h\n" + "smlal v26.4s, v5.4h, v21.4h\n" + "smlal2 v27.4s, v4.8h, v20.8h\n" + "uaddw v22.8h, v28.8h, v22.8b\n" + "smlal2 v27.4s, v5.8h, v21.8h\n" + "uaddw v23.8h, v28.8h, v23.8b\n" + "smlal v26.4s, v7.4h, v22.4h\n" + "smlal2 v27.4s, v7.8h, v22.8h\n" + "smlal v26.4s, v8.4h, v23.4h\n" + "smlal2 v27.4s, v8.8h, v23.8h\n" + + "dup v28.4s, w1\n" + "dup v29.4s, w9\n" + "sqrdmulh v24.4s, v24.4s, v28.4s\n" + "sqrdmulh v25.4s, v25.4s, v28.4s\n" + "sqrdmulh v26.4s, v26.4s, v28.4s\n" + "sqrdmulh v27.4s, v27.4s, v28.4s\n" + "dup v28.4s, w2\n" + "and v30.16b, v24.16b, v29.16b\n" + "and v31.16b, v25.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v24.4s, v24.4s, v30.4s\n" + "sqadd v25.4s, v25.4s, v31.4s\n" + "and v30.16b, v26.16b, v29.16b\n" + "and v31.16b, v27.16b, v29.16b\n" + "sshr v30.4s, v30.4s, #31\n" + "sshr v31.4s, v31.4s, #31\n" + "sqadd v26.4s, v26.4s, v30.4s\n" + "dup v30.4s, w3\n" + "sqadd v27.4s, v27.4s, v31.4s\n" + "dup v31.4s, w4\n" + "srshl v24.4s, v24.4s, v29.4s\n" + "srshl v25.4s, v25.4s, v29.4s\n" + "srshl v26.4s, v26.4s, v29.4s\n" + "srshl v27.4s, v27.4s, v29.4s\n" + "add v24.4s, v24.4s, v28.4s\n" + "add v25.4s, v25.4s, v28.4s\n" + "add v26.4s, v26.4s, v28.4s\n" + "add v27.4s, v27.4s, v28.4s\n" + "dup v28.8h, w0\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smax v26.4s, v26.4s, v30.4s\n" + "smax v27.4s, v27.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "smin v26.4s, v26.4s, v31.4s\n" + "smin v27.4s, v27.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn v26.4h, v26.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "sqxtn2 v26.8h, v27.4s\n" + "sqxtun v24.8b, v24.8h\n" + "sqxtun v26.8b, v26.8h\n" + "st1 {v24.8b}, [x6], x5\n" + "st1 {v26.8b}, [x6]\n" + "b " DEPTHWISECONV_LABEL_HEIGHT_1_END "f\n" + + // Handle bottom right output if exists. + DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER ":\n" + "dup v26.4s, w9\n" + "dup v27.4s, w1\n" + "dup v29.4s, w2\n" + + "smlal v24.4s, v0.4h, v9.4h\n" + "smlal2 v25.4s, v0.8h, v9.8h\n" + "smlal v24.4s, v1.4h, v10.4h\n" + "smlal2 v25.4s, v1.8h, v10.8h\n" + "smlal v24.4s, v2.4h, v11.4h\n" + "smlal2 v25.4s, v2.8h, v11.8h\n" + "smlal v24.4s, v3.4h, v12.4h\n" + "smlal2 v25.4s, v3.8h, v12.8h\n" + "smlal v24.4s, v4.4h, v13.4h\n" + "smlal2 v25.4s, v4.8h, v13.8h\n" + "smlal v24.4s, v5.4h, v14.4h\n" + "smlal2 v25.4s, v5.8h, v14.8h\n" + "smlal v24.4s, v6.4h, v15.4h\n" + "smlal2 v25.4s, v6.8h, v15.8h\n" + "smlal v24.4s, v7.4h, v16.4h\n" + "smlal2 v25.4s, v7.8h, v16.8h\n" + "smlal v24.4s, v8.4h, v17.4h\n" + "smlal2 v25.4s, v8.8h, v17.8h\n" + + "sqrdmulh v24.4s, v24.4s, v27.4s\n" + "sqrdmulh v25.4s, v25.4s, v27.4s\n" + "and v18.16b, v24.16b, v26.16b\n" + "and v19.16b, v25.16b, v26.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v24.4s, v24.4s, v18.4s\n" + "sqadd v25.4s, v25.4s, v19.4s\n" + "srshl v24.4s, v24.4s, v26.4s\n" + "srshl v25.4s, v25.4s, v26.4s\n" + "add v24.4s, v24.4s, v29.4s\n" + "add v25.4s, v25.4s, v29.4s\n" + "smax v24.4s, v24.4s, v30.4s\n" + "smax v25.4s, v25.4s, v30.4s\n" + "smin v24.4s, v24.4s, v31.4s\n" + "smin v25.4s, v25.4s, v31.4s\n" + "sqxtn v24.4h, v24.4s\n" + "sqxtn2 v24.8h, v25.4s\n" + "sqxtun v24.8b, v24.8h\n" + "st1 {v24.8b}, [x6]\n" + + DEPTHWISECONV_LABEL_HEIGHT_1_END ":\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), + [output_window_height] "+r"(output_window_height) + : + // Inputs. + [bias_ptr] "r"(bias_ptr), [input_row_size] "r"(input_row_size), + [input_depth] "r"(input_depth), + [output_window_width] "r"(output_window_width), + [input_width_increment] "r"(input_width_increment), + [input_height_increment] "r"(input_height_increment), + [output_height_increment] "r"(output_height_increment), + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", + "x9", "x10", "x11", "x12", "x13", "x14", "x15", + "x19", "x20"); +#undef DEPTHWISECONV_LABEL_HEIGHT_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_2_WIDTH_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_2_AFTER_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1 +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LOOP +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_1_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_WIDTH_2_LEFTOVER +#undef DEPTHWISECONV_LABEL_HEIGHT_1_END } }; -template <> -struct ConvKernel3x3FilterDepth8<4, 1, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - DotProductAndStore( - filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3, - input_4, input_5, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Third output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - - DotProductAndStore( - filter, input_3, input_4, input_5, input_6, input_7, input_8, input_0, - input_1, input_2, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Fourth output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - } -}; - -template <> -struct ConvKernel3x3FilterDepth8<2, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - Int32x8 acc_0, acc_1, acc_2, acc_3; - acc_0.low = vld1q_s32(bias_ptr); - acc_1.low = vld1q_s32(bias_ptr); - acc_2.low = vld1q_s32(bias_ptr); - acc_3.low = vld1q_s32(bias_ptr); - - bias_ptr += 4; - acc_0.high = vld1q_s32(bias_ptr); - acc_1.high = vld1q_s32(bias_ptr); - acc_2.high = vld1q_s32(bias_ptr); - acc_3.high = vld1q_s32(bias_ptr); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - - // Add scope for input registers to help the compiler know that it is - // not needed. - { - // To process 2x2 outputs using a 3x3 filter at stride 2, we require - // 5x5 inputs. We load the first 5x2 inputs at a time. - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, input_9; - - const uint8* ptr = input_ptr; - - // Load inputs. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load next inputs. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_9 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_9 = vaddq_s16(input_9, input_offset_vec); - } - - acc_0 = MultiplyAccumulateRow(acc_0, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_1 = MultiplyAccumulateRow(acc_1, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - - // Moving onto the two bottom outputs. - acc_2 = MultiplyAccumulateRow(acc_2, filter.f0, filter.f1, filter.f2, - input_0, input_1, input_2); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f0, filter.f1, filter.f2, - input_2, input_3, input_4); - - acc_2 = MultiplyAccumulateRow(acc_2, filter.f3, filter.f4, filter.f5, - input_5, input_6, input_7); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f3, filter.f4, filter.f5, - input_7, input_8, input_9); - - // Load last input row. - { - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4; +enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter }; - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - temp_3 = vld1_u8(ptr + 3 * input_depth); - temp_4 = vld1_u8(ptr + 4 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - } - - acc_2 = MultiplyAccumulateRow(acc_2, filter.f6, filter.f7, filter.f8, - input_0, input_1, input_2); - - acc_3 = MultiplyAccumulateRow(acc_3, filter.f6, filter.f7, filter.f8, - input_2, input_3, input_4); - } - - DownquantizeAndStore2x2Output(acc_0, acc_1, acc_2, acc_3, output_offset, - output_multiplier, output_shift, - output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - } -}; +template +struct DepthwiseConvPartial {}; template <> -struct ConvKernel3x3FilterDepth8<2, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - // Reuse 2x2 kernel twice. - ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, output_depth, - output_width); - - ConvKernel3x3FilterDepth8<2, 2, 2, 2>::Run( - input_ptr + 4 * input_depth, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr + 2 * output_depth, output_depth, output_width); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 1x1 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the 1x1 input and filter values. + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w10\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "cmp x11, #16\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w10, w10\n" + "dup v29.4s, w10\n" + "ldr w10, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w9\n" + "ldr w9, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w10\n" + "dup v25.8h, w9\n" + + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x11, x11, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x11, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v8", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", + "v29", "v30", "v31", + // We use these general-purpose registers. + "x9", "x10", "x11"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<2, 1, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - const int output_row_size = output_depth * output_width; - - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_row_size; - - ptr += input_row_size; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - - DotProductAndStore( - filter, input_6, input_7, input_8, input_0, input_1, input_2, input_3, - input_4, input_5, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x2 input and + // filter values. + + // Load input and filter values. + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "ldr x9, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "cmp x15, #16\n" + "add x12, %[input_ptr], x15\n" + "add x13, %[input_ptr], x9\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "add x14, x13, x15\n" + "ld1 {v9.8b}, [x12], #8\n" + "ldr x6, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + + "add x9, %[filter_ptr], x15\n" + "ld1 {v10.8b}, [x13], #8\n" + "add x10, %[filter_ptr], x6\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "add x11, x10, x15\n" + "ld1 {v1.8b}, [x9], #8\n" + "ld1 {v2.8b}, [x10], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + // Load constants. + "ldr w6, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w7, w7\n" + "dup v29.4s, w7\n" + "ldr w7, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w6\n" + "ldr w6, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w7\n" + "dup v25.8h, w6\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "subs x15, x15, #8\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [%[input_ptr]], #8\n" + "cmp x15, #16\n" + "ld1 {v0.8b}, [%[filter_ptr]], #8\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], #8\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "ld1 {v1.8b}, [x9], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], #8\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v2.8b}, [x10], #8\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x14], #8\n" + "ld1 {v3.8b}, [x11], #8\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + // We use these general-purpose registers. + "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<1, 2, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_depth; - - ptr = input_ptr + 3 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - DotProductAndStore( - filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8, - input_6, input_7, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 2x3 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 2x3 input and + // filter values. + + // Load input and filter values. + "ldr x7, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x9, %[filter_ptr]\n" + "ldr x14, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + + "ld1 {v8.8b}, [x12], x7\n" + "add x10, x9, x14\n" + "ld1 {v9.8b}, [x12], x7\n" + "cmp x15, #16\n" + "ld1 {v10.8b}, [x12]\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13], x7\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x13], x7\n" + "ld1 {v13.8b}, [x13]\n" + + "ld1 {v0.8b}, [x9], x7\n" + "ld1 {v1.8b}, [x9], x7\n" + "ld1 {v2.8b}, [x9]\n" + "ld1 {v3.8b}, [x10], x7\n" + "ld1 {v4.8b}, [x10], x7\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x9, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x7\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x10, x9, x14\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "ld1 {v9.8b}, [x12], x7\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x12]\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x9], x7\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13], x7\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x9], x7\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x13], x7\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9]\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x13]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x10], x7\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x7\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; template <> -struct ConvKernel3x3FilterDepth8<1, 4, 2, 2> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - uint8x8_t temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, - temp_8; - - const uint8* ptr = input_ptr; - - // Load all inputs for top output. - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - temp_2 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - temp_5 = vld1_u8(ptr + 2 * input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - temp_8 = vld1_u8(ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Second output. - output_ptr += output_depth; - - ptr = input_ptr + 3 * input_depth; - temp_0 = vld1_u8(ptr); - temp_1 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_3 = vld1_u8(ptr); - temp_4 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_6 = vld1_u8(ptr); - temp_7 = vld1_u8(ptr + input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - - DotProductAndStore( - filter, input_2, input_0, input_1, input_5, input_3, input_4, input_8, - input_6, input_7, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Third output. - output_ptr += output_depth; - - ptr = input_ptr + 5 * input_depth; - temp_2 = vld1_u8(ptr); - temp_0 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_5 = vld1_u8(ptr); - temp_3 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_8 = vld1_u8(ptr); - temp_6 = vld1_u8(ptr + input_depth); - - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - - input_2 = vaddq_s16(input_2, input_offset_vec); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - - DotProductAndStore( - filter, input_1, input_2, input_0, input_4, input_5, input_3, input_7, - input_8, input_6, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - - // Fourth output. - output_ptr += output_depth; - - ptr = input_ptr + 7 * input_depth; - temp_1 = vld1_u8(ptr); - temp_2 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_4 = vld1_u8(ptr); - temp_5 = vld1_u8(ptr + input_depth); - ptr += input_row_size; - temp_7 = vld1_u8(ptr); - temp_8 = vld1_u8(ptr + input_depth); - - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); +struct DepthwiseConvPartial { + static inline void Run(const uint8* input_ptr, const uint8* filter_ptr, + const int32* bias_ptr, uint8* output_ptr, + const DepthwiseConvParams* params_ptr) { +#define DEPTHWISECONV_LABEL_DEPTH_8_LOOP "1" +#define DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "2" + asm volatile( + // Performs depthwise convolutions for an input window of size 3x2 and + // padding of 1 across the full depth. Expects |input_ptr| and + // |filter_ptr| to be pointing to the beginning of the 3x2 input and + // filter values. + + // Load input and filter values. + "ldr x6, [%[params_ptr], #" STR(OFFSET_INPUT_DEPTH) "]\n" + "mov x12, %[input_ptr]\n" + "ldr x11, [%[params_ptr], #" STR(OFFSET_INPUT_ROW_SIZE) "]\n" + "mov x7, %[filter_ptr]\n" + "ldr x5, [%[params_ptr], #" STR(OFFSET_FILTER_ROW_SIZE) "]\n" + "add x13, x12, x11\n" + "ldr x15, [%[params_ptr], #" STR(OFFSET_OUTPUT_DEPTH) "]\n" + "add x14, x13, x11\n" + + "ld1 {v8.8b}, [x12], x6\n" + "add x9, x7, x5\n" + "ld1 {v9.8b}, [x12]\n" + "cmp x15, #16\n" + "add x10, x9, x5\n" + "ld1 {v10.8b}, [x13], x6\n" + "add %[input_ptr], %[input_ptr], #8\n" + "ld1 {v11.8b}, [x13]\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "ld1 {v12.8b}, [x14], x6\n" + "ld1 {v13.8b}, [x14]\n" + + "ld1 {v0.8b}, [x7], x6\n" + "ld1 {v1.8b}, [x7]\n" + "ld1 {v2.8b}, [x9], x6\n" + "ld1 {v3.8b}, [x9]\n" + "ld1 {v4.8b}, [x10], x6\n" + "ld1 {v5.8b}, [x10]\n" + + // Load constants. + "ldr w12, [%[params_ptr], #" STR(OFFSET_INPUT_OFFSET) "]\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_MULTIPLIER) "]\n" + "dup v26.8h, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_OFFSET) "]\n" + "dup v27.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_SHIFT) "]\n" + "dup v28.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MIN) "]\n" + "neg w13, w13\n" + "dup v29.4s, w13\n" + "ldr w13, [%[params_ptr], #" STR(OFFSET_OUTPUT_ACTIVATION_MAX) "]\n" + "dup v30.4s, w12\n" + "ldr w12, [%[params_ptr], #" STR(OFFSET_FILTER_OFFSET) "]\n" + "dup v31.4s, w13\n" + "dup v25.8h, w12\n" + + // Add input and filter offsets. + "uaddw v8.8h, v26.8h, v8.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "blt " DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP "f\n" + + //"loop_%=:\n" + DEPTHWISECONV_LABEL_DEPTH_8_LOOP ":\n" + "mov x12, %[input_ptr]\n" + "subs x15, x15, #8\n" + "add x13, x12, x11\n" + "cmp x15, #16\n" + "add x14, x13, x11\n" + "add %[input_ptr], %[input_ptr], #8\n" + + "smlal v16.4s, v0.4h, v8.4h\n" + "mov x7, %[filter_ptr]\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "ld1 {v8.8b}, [x12], x6\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "add x9, x7, x5\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "add x10, x9, x5\n" + "ld1 {v9.8b}, [x12]\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "add %[filter_ptr], %[filter_ptr], #8\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "ld1 {v10.8b}, [x13], x6\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "ld1 {v0.8b}, [x7], x6\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "ld1 {v11.8b}, [x13]\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "ld1 {v1.8b}, [x7]\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "ld1 {v12.8b}, [x14], x6\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "ld1 {v2.8b}, [x9], x6\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + "ld1 {v13.8b}, [x14]\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "ld1 {v3.8b}, [x9]\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "ld1 {v4.8b}, [x10], x6\n" + "and v18.16b, v16.16b, v29.16b\n" + "ld1 {v5.8b}, [x10]\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "uaddw v8.8h, v26.8h, v8.8b\n" + "st1 {v16.8b}, [%[output_ptr]], #8\n" + "uaddw v9.8h, v26.8h, v9.8b\n" + "uaddw v10.8h, v26.8h, v10.8b\n" + "uaddw v11.8h, v26.8h, v11.8b\n" + "uaddw v12.8h, v26.8h, v12.8b\n" + "uaddw v13.8h, v26.8h, v13.8b\n" + + "uaddw v0.8h, v25.8h, v0.8b\n" + "uaddw v1.8h, v25.8h, v1.8b\n" + "uaddw v2.8h, v25.8h, v2.8b\n" + "ld1 {v16.4s}, [%[bias_ptr]], #16\n" + "uaddw v3.8h, v25.8h, v3.8b\n" + "ld1 {v17.4s}, [%[bias_ptr]], #16\n" + "uaddw v4.8h, v25.8h, v4.8b\n" + "uaddw v5.8h, v25.8h, v5.8b\n" + + "bge " DEPTHWISECONV_LABEL_DEPTH_8_LOOP "b\n" + + DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP ":\n" + "smlal v16.4s, v0.4h, v8.4h\n" + "smlal2 v17.4s, v0.8h, v8.8h\n" + "smlal v16.4s, v1.4h, v9.4h\n" + "smlal2 v17.4s, v1.8h, v9.8h\n" + "smlal v16.4s, v2.4h, v10.4h\n" + "smlal2 v17.4s, v2.8h, v10.8h\n" + "smlal v16.4s, v3.4h, v11.4h\n" + "smlal2 v17.4s, v3.8h, v11.8h\n" + "smlal v16.4s, v4.4h, v12.4h\n" + "smlal2 v17.4s, v4.8h, v12.8h\n" + "smlal v16.4s, v5.4h, v13.4h\n" + "smlal2 v17.4s, v5.8h, v13.8h\n" + + "sqrdmulh v16.4s, v16.4s, v27.4s\n" + "sqrdmulh v17.4s, v17.4s, v27.4s\n" + "and v18.16b, v16.16b, v29.16b\n" + "and v19.16b, v17.16b, v29.16b\n" + "sshr v18.4s, v18.4s, #31\n" + "sshr v19.4s, v19.4s, #31\n" + "sqadd v16.4s, v16.4s, v18.4s\n" + "sqadd v17.4s, v17.4s, v19.4s\n" + "srshl v16.4s, v16.4s, v29.4s\n" + "srshl v17.4s, v17.4s, v29.4s\n" + "add v16.4s, v16.4s, v28.4s\n" + "add v17.4s, v17.4s, v28.4s\n" + "smax v16.4s, v16.4s, v30.4s\n" + "smax v17.4s, v17.4s, v30.4s\n" + "smin v16.4s, v16.4s, v31.4s\n" + "smin v17.4s, v17.4s, v31.4s\n" + "sqxtn v16.4h, v16.4s\n" + "sqxtn2 v16.8h, v17.4s\n" + "sqxtun v16.8b, v16.8h\n" + "st1 {v16.8b}, [%[output_ptr]]\n" + : + // Outputs. + [filter_ptr] "+r"(filter_ptr), [input_ptr] "+r"(input_ptr), + [output_ptr] "+r"(output_ptr), [bias_ptr] "+r"(bias_ptr) + : + // Inputs. + [params_ptr] "r"(params_ptr) + : + // Clobbers. + "cc", "memory", + // We use these NEON registers. + "v0", "v1", "v2", "v3", "v4", "v5", "v8", "v9", "v10", "v11", "v12", + "v13", "v16", "v17", "v18", "v19", "v25", "v26", "v27", "v28", "v29", + "v30", "v31", + // We use these general-purpose registers. + "x5", "x6", "x7", "x9", "x10", "x11", "x12", "x13", "x14", "x15"); +#undef DEPTHWISECONV_LABEL_DEPTH_8_LOOP +#undef DEPTHWISECONV_LABEL_DEPTH_8_AFTER_LOOP } }; -template -struct ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_ptr, int input_depth, - int32 input_offset, int input_row_size, - const uint8* filter_ptr, int32 filter_offset, - const int32* bias_ptr, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_ptr, - int output_depth, int output_width) { - Filter3x3x8 filter = Load3x3Filter(filter_ptr, filter_offset, output_depth); - - int16x8_t input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8; - - uint8x8_t temp_0 = vld1_u8(input_ptr); - uint8x8_t temp_1 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_2 = vld1_u8(input_ptr + 2 * input_depth); - - input_ptr += input_row_size; - uint8x8_t temp_3 = vld1_u8(input_ptr); - uint8x8_t temp_4 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_5 = vld1_u8(input_ptr + 2 * input_depth); - - input_ptr += input_row_size; - uint8x8_t temp_6 = vld1_u8(input_ptr); - uint8x8_t temp_7 = vld1_u8(input_ptr + input_depth); - uint8x8_t temp_8 = vld1_u8(input_ptr + 2 * input_depth); - - input_0 = vreinterpretq_s16_u16(vmovl_u8(temp_0)); - input_1 = vreinterpretq_s16_u16(vmovl_u8(temp_1)); - input_2 = vreinterpretq_s16_u16(vmovl_u8(temp_2)); - input_3 = vreinterpretq_s16_u16(vmovl_u8(temp_3)); - input_4 = vreinterpretq_s16_u16(vmovl_u8(temp_4)); - input_5 = vreinterpretq_s16_u16(vmovl_u8(temp_5)); - input_6 = vreinterpretq_s16_u16(vmovl_u8(temp_6)); - input_7 = vreinterpretq_s16_u16(vmovl_u8(temp_7)); - input_8 = vreinterpretq_s16_u16(vmovl_u8(temp_8)); - - const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); - input_0 = vaddq_s16(input_0, input_offset_vec); - input_1 = vaddq_s16(input_1, input_offset_vec); - input_2 = vaddq_s16(input_2, input_offset_vec); - input_3 = vaddq_s16(input_3, input_offset_vec); - input_4 = vaddq_s16(input_4, input_offset_vec); - input_5 = vaddq_s16(input_5, input_offset_vec); - input_6 = vaddq_s16(input_6, input_offset_vec); - input_7 = vaddq_s16(input_7, input_offset_vec); - input_8 = vaddq_s16(input_8, input_offset_vec); - - DotProductAndStore( - filter, input_0, input_1, input_2, input_3, input_4, input_5, input_6, - input_7, input_8, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, output_ptr); - } -}; - -inline void ShuffleInput(const uint8* input_ptr, int input_depth, - int input_width, int input_height, int output_depth, - int output_width, int output_height, - uint8* output_ptr) { - const int input_row_size = input_depth * input_width; - - for (int y = 0; y < output_height; y++) { +#undef OFFSET_INPUT_DEPTH +#undef OFFSET_INPUT_ROW_SIZE +#undef OFFSET_OUTPUT_DEPTH +#undef OFFSET_OUTPUT_ROW_SIZE +#undef OFFSET_INPUT_OFFSET +#undef OFFSET_OUTPUT_OFFSET +#undef OFFSET_FILTER_OFFSET +#undef OFFSET_OUTPUT_MULTIPLIER +#undef OFFSET_OUTPUT_ACTIVATION_MIN +#undef OFFSET_OUTPUT_ACTIVATION_MAX +#undef OFFSET_OUTPUT_SHIFT +#undef OFFSET_INPUT_WIDTH +#undef OFFSET_INPUT_HEIGHT +#undef OFFSET_OUTPUT_WIDTH +#undef OFFSET_OUTPUT_HEIGHT +#undef STR +#undef STR_UNEXPANDED + +// Copies a subset of the input designated by |input_ptr| into |output_ptr| +// with the specified output dimensions. Supports output depths of 64 only as +// this is the cache line size. +inline void ShuffleInput(const uint8* input_ptr, int64_t input_depth, + int32 input_width, int32 input_height, + int64_t output_depth, int32 output_width, + int32 output_height, uint8* output_ptr) { + const int64_t input_row_size = input_depth * input_width; + for (int32 y = 0; y < output_height; y++) { const uint8* ptr = input_ptr; - for (int x = 0; x < output_width; x++) { + for (int32 x = 0; x < output_width; x++) { memcpy(output_ptr, ptr, output_depth); output_ptr += output_depth; ptr += input_depth; @@ -3873,561 +2937,262 @@ inline void ShuffleInput(const uint8* input_ptr, int input_depth, } } -template -struct ConvRow3x3FilterDepth8 {}; - -template -struct ConvRow3x3FilterDepth8<1, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 1x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<1, 4, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * kFixedStrideWidth * input_depth; - output_data += 4 * output_depth; - } - - // 1x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<1, 1, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } +// Calculates the input size depending on stride and output. +inline int32 get_shuffle_input_size(int32 stride, int32 output) { + return stride * (output - 1) + 3; +} - input_data += kFixedStrideWidth * input_depth; - output_data += output_depth; - } +// Indicates the input and output dimensions used when shuffling input +// activations. +struct ShuffleParams { + int32 output_width; + int32 output_height; + int32 input_width; + int32 input_height; + + ShuffleParams() = default; + ShuffleParams(int32 output_width, int32 output_height, int32 stride_width, + int32 stride_height) + : output_width(output_width) + , output_height(output_height) + , input_width(get_shuffle_input_size(stride_width, output_width)) + , input_height(get_shuffle_input_size(stride_height, output_height)) { } }; -template -struct ConvRow3x3FilterDepth8<2, kFixedStrideWidth, kFixedStrideHeight> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 2x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 4, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * kFixedStrideWidth * input_depth; - output_data += 4 * output_depth; - } - - // 2x2 at a time. - for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 2, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * kFixedStrideWidth * input_depth; - output_data += 2 * output_depth; - } - - // 2x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<2, 1, kFixedStrideWidth, kFixedStrideHeight>:: - Run(input_ptr, input_depth, input_offset, input_row_size, - filter_ptr, filter_offset, bias_ptr, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += kFixedStrideWidth * input_depth; - output_data += output_depth; +template +struct DepthwiseConvThroughDepth { + // Runs the DepthwiseConvWindow kernels through the depth dimension from + // |start_depth| to |end_depth|. Keep this not inlined to maintain a small + // binary size. We use a DepthwiseConvParams struct for read only params + // to minimize call overhead. + static __attribute__((noinline)) void Run(const uint8* input_ptr, + const uint8* filter_ptr, const int32* bias_ptr, uint8* output_ptr, + int64_t start_depth, int64_t end_depth, int64_t input_depth, + int64_t input_row_size, int32 output_window_height, + int32 output_window_width, const DepthwiseConvParams& params) { + for (; start_depth <= end_depth - 8; start_depth += 8) { + DepthwiseConvWindow<8, kStrideWidth, kStrideHeight>::Run( + input_ptr, filter_ptr, bias_ptr, output_ptr, input_depth, + input_row_size, output_window_height, output_window_width, ¶ms); + input_ptr += 8; + output_ptr += 8; + filter_ptr += 8; + bias_ptr += 8; } } }; -template <> -struct ConvRow3x3FilterDepth8<4, 1, 1> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - - // 4x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } +template +struct DepthwiseConvMultiRow { + using ConvKernel = DepthwiseConvThroughDepth; - input_data += 4 * input_depth; - output_data += 4 * output_depth; - } - - // Handle the rest of the right side. - // 4x2 at a time. - for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 2, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * input_depth; - output_data += 2 * output_depth; - } - - // 4x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 1, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += input_depth; - output_data += output_depth; - } - } -}; - -template <> -struct ConvRow3x3FilterDepth8<4, 2, 2> { - // The buffer size of the shuffled input. - static inline constexpr int ShuffleWorkspaceSize() { return 64 * 9 * 9; } - - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, + static inline void Run(const uint8* input_data, int32 start_x, int32 end_x, + const uint8* filter_data, const int32* bias_data, + uint8* output_data, const DepthwiseConvParams& params, + const ShuffleParams& shuffle_params, uint8* shuffle_workspace) { - // Branch and cache misses increase substantially with stride 2 kernels. - // Adding prefetching reduces latency by as much as 2x. - const int i0 = 0; - const int i1 = input_depth; - const int i2 = 2 * input_depth; - const int i3 = 3 * input_depth; - const int i4 = 4 * input_depth; - const int i5 = 5 * input_depth; - const int i6 = 6 * input_depth; - const int i7 = 7 * input_depth; - const int i8 = 8 * input_depth; - -#define DEPTHWISECONV_PRELOAD_ROW(input_ptr, i) \ - preload_l1_keep(input_ptr + i * input_row_size + i0); \ - preload_l1_keep(input_ptr + i * input_row_size + i1); \ - preload_l1_keep(input_ptr + i * input_row_size + i2); \ - preload_l1_keep(input_ptr + i * input_row_size + i3); \ - preload_l1_keep(input_ptr + i * input_row_size + i4); \ - preload_l1_keep(input_ptr + i * input_row_size + i5); \ - preload_l1_keep(input_ptr + i * input_row_size + i6); \ - preload_l1_keep(input_ptr + i * input_row_size + i7); \ - preload_l1_keep(input_ptr + i * input_row_size + i8); - - int out_x = start_x; - // 4x4 at a time. - for (; out_x <= output_width - 4; out_x += 4) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - int depth = 0; - for (; depth <= output_depth - 64; depth += 64) { - // Preload 9x9 input. - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8); - - // For a large input window (64x9x9) that is small enough to fit in L1 - // cache, copy the input into a separate buffer and run the kernel on - // this new buffer. This reduces the likelihood of cache misses when - // the kernel is loading input data. If this size is ever changed, - // update the ShuffleWorkspaceSize() function to return the new size. - ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 9, - 9, shuffle_workspace); - const uint8* shuffled_ptr = &shuffle_workspace[0]; - - for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run( - shuffled_ptr, 64, input_offset, 64 * 9, filter_ptr, filter_offset, - bias_ptr, output_offset, output_multiplier, output_shift, - output_activation_min, output_activation_max, output_ptr, - output_depth, output_width); - - shuffled_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; + TFLITE_DCHECK(shuffle_params.input_height == + get_shuffle_input_size(kStrideHeight, shuffle_params.output_height)); + TFLITE_DCHECK(shuffle_params.input_width == + get_shuffle_input_size(kStrideWidth, shuffle_params.output_width)); + TFLITE_DCHECK(64 * shuffle_params.input_width * shuffle_params.input_height + <= DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE); + + int32 out_x = start_x; + + // Run shuffling on inputs with sufficiently large depth and width. When + // these parameters are large enough, more time is taken to load inputs + // from memory. At this point, it becomes useful to prefetch and + // preshuffle the input data to maximize locality. + if (params.output_depth > 64 || + (params.output_depth <= 64 && params.input_width > 150)) { + for (; out_x <= (end_x - shuffle_params.output_width); + out_x += shuffle_params.output_width) { + const uint8* input_ptr = input_data; + const int32* bias_ptr = bias_data; + const uint8* filter_ptr = filter_data; + uint8* output_ptr = output_data; + int64_t depth = 0; + const int64_t shuffle_row_size = 64 * shuffle_params.input_width; + + for (; depth <= params.output_depth - 64; depth += 64) { + // Preload. + const uint8* h_ptr = input_ptr; + for (int32 i = 0; i < shuffle_params.input_height; i++) { + const uint8* ptr = h_ptr; + for (int32 j = 0; j < shuffle_params.input_width; j++) { + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); + ptr += params.input_depth; + } + h_ptr += params.input_row_size; + } + + // For a large enough input, shuffle into buckets. + ShuffleInput(input_ptr, params.input_depth, params.input_width, + params.input_height, 64, shuffle_params.input_width, + shuffle_params.input_height, shuffle_workspace); + ConvKernel::Run(shuffle_workspace, filter_ptr, bias_ptr, output_ptr, + 0, 64, 64, shuffle_row_size, + shuffle_params.output_height, + shuffle_params.output_width, params); + input_ptr += 64; + output_ptr += 64; + filter_ptr += 64; + bias_ptr += 64; } - input_ptr += 64; - } - - // Preload 9x9 input one more time for the rest of the depth. - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 0); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 1); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 2); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 3); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 4); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 5); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 6); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 7); - DEPTHWISECONV_PRELOAD_ROW(input_ptr, 8); - - for (; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 4, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 4 * 2 * input_depth; - output_data += 4 * output_depth; - } - -#undef DEPTHWISECONV_PRELOAD_ROW - - // Handle the rest of the right side. - // 4x2 at a time. - for (; out_x <= output_width - 2; out_x += 2) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; + // Preload. + const uint8* h_ptr = input_ptr; + for (int32 i = 0; i < shuffle_params.input_height; i++) { + const uint8* ptr = h_ptr; + for (int32 j = 0; j < shuffle_params.input_width; j++) { + asm volatile("prfm pldl1keep, [%[ptr]]\n" ::[ptr] "r"(ptr) :); + ptr += params.input_depth; + } + h_ptr += params.input_row_size; + } - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 2, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); + // Handle leftover depth. + ConvKernel::Run(input_ptr, filter_ptr, bias_ptr, output_ptr, + depth, params.output_depth, params.input_depth, + params.input_row_size, shuffle_params.output_height, + shuffle_params.output_width, params); - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; + input_data += + shuffle_params.output_width * kStrideWidth * params.input_depth; + output_data += shuffle_params.output_width * params.output_depth; } - - input_data += 2 * 2 * input_depth; - output_data += 2 * output_depth; } - // 4x1 at a time. - for (; out_x < output_width; out_x++) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - for (int depth = 0; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<4, 1, 2, 2>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - - input_data += 2 * input_depth; - output_data += output_depth; + const int32 output_leftover_width = end_x - out_x; + if (output_leftover_width > 0) { + ConvKernel::Run(input_data, filter_data, bias_data, output_data, 0, + params.output_depth, params.input_depth, + params.input_row_size, shuffle_params.output_height, + output_leftover_width, params); } } }; -template <> -struct ConvRow3x3FilterDepth8<8, 2, 2> { - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - // Reuse 4 row kernels twice. - ConvRow3x3FilterDepth8<4, 2, 2>::Run( - input_data, start_x, start_y, input_depth, input_width, input_height, - input_row_size, input_offset, filter_data, filter_offset, bias_data, - output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_depth, output_width, - shuffle_workspace); - - ConvRow3x3FilterDepth8<4, 2, 2>::Run( - input_data + 2 * 4 * input_row_size, start_x, start_y + 4, input_depth, - input_width, input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data + 4 * output_depth * output_width, output_depth, - output_width, shuffle_workspace); +// Processes the borders of the input for pad_width and pad_height = 1. +// Calls 4 asm kernels: +// * 1x1 input shape. +// * Corner edges. +// * Horizontal edges. +// * Vertical edges. +inline void DepthwiseConvHandlePadding(const uint8* input_data, + const uint8* filter_data, const int32* bias_data, uint8* output_data, + const DepthwiseConvParams& params) { + if (params.input_width == 1 && params.input_height == 1) { + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + DepthwiseConvPartial::Run(input_data, filter_ptr, + bias_data, output_data, ¶ms); + return; } -}; -template <> -struct ConvRow3x3FilterDepth8<8, 1, 1> { - // The buffer size of the shuffled input. - static inline constexpr int ShuffleWorkspaceSize() { return 64 * 10 * 10; } - - static inline void Run(const uint8* input_data, int start_x, int start_y, - int input_depth, int input_width, int input_height, - int input_row_size, int32 input_offset, - const uint8* filter_data, int32 filter_offset, - const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - int output_depth, int output_width, - uint8* shuffle_workspace) { - int out_x = start_x; - // 8x8 at a time. - for (; out_x <= output_width - 8; out_x += 8) { - const int32* bias_ptr = bias_data; - const uint8* filter_ptr = filter_data; - - const uint8* input_ptr = input_data; - uint8* output_ptr = output_data; - - int depth = 0; - for (; depth <= output_depth - 64; depth += 64) { - // For a large input window (64x10x10) that is small enough to fit in L1 - // cache, copy the input into a separate buffer and run the kernel on - // this new buffer. This reduces the likelihood of cache misses when - // the kernel is loading input data. If the size of the input window - // changes, update the function ShuffleWorkspaceSize() with the new - // size. - ShuffleInput(input_ptr, input_depth, input_width, input_height, 64, 10, - 10, shuffle_workspace); - const uint8* shuffled_ptr = shuffle_workspace; - - for (int micro_depth = 0; micro_depth <= 64 - 8; micro_depth += 8) { - ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run( - shuffled_ptr, 64, input_offset, 64 * 10, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - shuffled_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } - input_ptr += 64; - } + const int32 out_x_start_corner = 0; + const int32 out_x_end_corner = params.output_width - 1; + const int32 out_y_start_corner = 0; + const int32 out_y_end_corner = params.output_height - 1; + + // Handle top row. + const uint8* input_ptr = input_data; + const uint8* filter_ptr = filter_data + params.filter_row_size + + params.output_depth; + uint8* output_ptr = output_data; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); + + input_ptr += (params.stride_width - 1) * params.input_depth; + filter_ptr = filter_data + params.filter_row_size; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } - for (; depth <= output_depth - 8; depth += 8) { - ConvKernel3x3FilterDepth8<8, 8, 1, 1>::Run( - input_ptr, input_depth, input_offset, input_row_size, filter_ptr, - filter_offset, bias_ptr, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_ptr, output_depth, output_width); - - input_ptr += 8; - output_ptr += 8; - filter_ptr += 8; - bias_ptr += 8; - } + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); - input_data += 8 * input_depth; - output_data += 8 * output_depth; - } + // Handle left side. + input_ptr = input_data + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data + params.input_depth; + output_ptr = output_data + params.output_row_size; - // Handle the rest of the right side by re-using 4 row kernels twice. - ConvRow3x3FilterDepth8<4, 1, 1>::Run( - input_data, out_x, start_y, input_depth, input_width, input_height, - input_row_size, input_offset, filter_data, filter_offset, bias_data, - output_offset, output_multiplier, output_shift, output_activation_min, - output_activation_max, output_data, output_depth, output_width, - shuffle_workspace); - - ConvRow3x3FilterDepth8<4, 1, 1>::Run( - input_data + 4 * input_row_size, out_x, start_y + 4, input_depth, - input_width, input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, output_multiplier, - output_shift, output_activation_min, output_activation_max, - output_data + 4 * output_depth * output_width, output_depth, - output_width, shuffle_workspace); + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; } -}; -inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims, - const Dims<4>& filter_dims, - int stride_width, int stride_height, - int pad_width, int pad_height, - int depth_multiplier, - const Dims<4>& output_dims) { - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - - bool supported = filter_width == 3 && filter_height == 3 && - depth_multiplier == 1 && - (stride_width == 1 || stride_width == 2) && - (stride_height == 1 || stride_height == 2) && - (stride_width == stride_height) && pad_width == 0 && - pad_height == 0 && (input_depth % 8) == 0; + // Handle right side. + input_ptr = input_data + (params.input_width - 2) * params.input_depth + + (params.stride_width - 1) * params.input_row_size; + filter_ptr = filter_data; + output_ptr = output_data + params.output_row_size + + (params.output_width - 1) * params.output_depth; + + for (int32 out_y = out_y_start_corner + 1; out_y < out_y_end_corner; + out_y++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_row_size; + output_ptr += params.output_row_size; + } + + // Handle bottom row. + input_ptr = input_data + (params.input_height - 2) * params.input_row_size; + filter_ptr = filter_data + params.output_depth; + output_ptr = output_data + + (params.output_height - 1) * params.output_row_size; + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); + + input_ptr += (params.stride_width == 1) ? 0 : params.input_depth; + filter_ptr = filter_data; + output_ptr += params.output_depth; + + for (int32 out_x = out_x_start_corner + 1; out_x < out_x_end_corner; + out_x++) { + DepthwiseConvPartial::Run( + input_ptr, filter_ptr, bias_data, output_ptr, ¶ms); + input_ptr += params.stride_width * params.input_depth; + output_ptr += params.output_depth; + } + + DepthwiseConvPartial::Run(input_ptr, filter_ptr, + bias_data, output_ptr, ¶ms); +} + +inline bool Fast3x3FilterKernelSupported( + const Dims<4>& input_dims, const Dims<4>& filter_dims, int32 stride_width, + int32 stride_height, int32 pad_width, int32 pad_height, + int32 depth_multiplier, const Dims<4>& output_dims, int32 output_shift) { + const int32 input_height = ArraySize(input_dims, 2); + const int32 input_width = ArraySize(input_dims, 1); + const int32 input_depth = ArraySize(input_dims, 0); + const int32 filter_height = ArraySize(filter_dims, 2); + const int32 filter_width = ArraySize(filter_dims, 1); + const int32 output_height = ArraySize(output_dims, 2); + const int32 output_width = ArraySize(output_dims, 1); + + bool supported = + filter_width == 3 && filter_height == 3 && depth_multiplier == 1 && + (stride_width == 1 || stride_width == 2) && + (stride_height == 1 || stride_height == 2) && + (stride_width == stride_height) && (pad_width == 0 || pad_width == 1) && + (pad_height == 0 || pad_height == 1) && (pad_width == pad_height) && + (input_depth % 8) == 0 && (output_shift > 0); if (!supported) { return false; @@ -4436,145 +3201,193 @@ inline bool Fast3x3FilterKernelSupported(const Dims<4>& input_dims, // Handle case where padding is zero but padding type is not kValid. // This would require special boundary case handling that is not supported. - const int out_x = output_width - 1; - const int out_y = output_height - 1; + const int32 out_x = output_width - 1; + const int32 out_y = output_height - 1; - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; + const int32 in_x_origin = (out_x * stride_width) - pad_width; + const int32 in_y_origin = (out_y * stride_height) - pad_height; - const int in_x_end = in_x_origin + filter_width; - const int in_y_end = in_y_origin + filter_height; + const int32 in_x_end = in_x_origin + filter_width; + const int32 in_y_end = in_y_origin + filter_height; // Supported only if filter on the right and bottom boundary lies completely - // within the input. - return in_x_end <= input_width && in_y_end <= input_height; + // within the input if padding is zero. + if (pad_width == 0 && pad_height == 0) { + return in_x_end <= input_width && in_y_end <= input_height; + } + + // Else if padding is 1, supported if bottom right filter lies +1 past input + // width and height. + supported = in_x_end <= (input_width + 1) && in_y_end <= (input_height + 1); + + if (!supported) { + return false; + } + + // Shapes with width 1 and height > 1, and vice versa are not supported yet. + if (input_width == 1) { + supported = (input_width == input_height); + } else if (input_height == 1) { + supported = (input_width == input_height); + } + return supported; } inline void DepthwiseConv3x3Filter( const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, int stride_width, - int stride_height, int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int input_depth = ArraySize(input_dims, 0); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - - // Algorithm assumes below constraints. It is optimized for depth multiplier - // of 1, 3x3 filter, no padding and strides 1 and 2. - TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + const int32* bias_data, const Dims<4>& bias_dims, int32 stride_width, + int32 stride_height, int32 pad_width, int32 pad_height, + int32 depth_multiplier, int32 output_offset, int32 output_multiplier, + int32 output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConvParams params; + params.input_depth = ArraySize(input_dims, 0); + params.input_width = ArraySize(input_dims, 1); + params.input_height = ArraySize(input_dims, 2); + params.input_row_size = params.input_depth * params.input_width; + params.input_offset = input_offset; + params.stride_width = stride_width; + params.stride_height = stride_height; + params.output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + params.output_width = ArraySize(output_dims, 1); + params.output_height = ArraySize(output_dims, 2); + params.output_row_size = params.output_depth * params.output_width; + params.output_offset = output_offset; + params.filter_offset = filter_offset; + params.output_multiplier = output_multiplier; + params.output_shift = output_shift; + params.output_activation_min = output_activation_min; + params.output_activation_max = output_activation_max; + + const int32 filter_height = ArraySize(filter_dims, 2); + const int32 filter_width = ArraySize(filter_dims, 1); + params.filter_row_size = params.output_depth * filter_width; + + // Algorithm assumes below constraints. It is optimized for depth + // multiplier of 1, 3x3 filter, no padding and strides 1 and 2. + TFLITE_DCHECK(params.output_depth == params.input_depth * depth_multiplier); TFLITE_DCHECK(depth_multiplier == 1); TFLITE_DCHECK(filter_height == 3); TFLITE_DCHECK(filter_width == 3); - TFLITE_DCHECK(pad_height == 0); - TFLITE_DCHECK(pad_width == 0); TFLITE_DCHECK(stride_height == 1 || stride_height == 2); TFLITE_DCHECK(stride_width == 1 || stride_width == 2); TFLITE_DCHECK(stride_width == stride_height); + TFLITE_DCHECK(pad_height == 0 || pad_height == 1); + TFLITE_DCHECK(pad_width == 0 || pad_width == 1); + TFLITE_DCHECK(pad_width == pad_height); + + const int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int64_t input_batch_size = params.input_row_size * params.input_height; + const int64_t output_batch_size = + params.output_row_size * params.output_height; + + ShuffleParams one_row_shuffle_params, two_row_shuffle_params, + four_row_shuffle_params, eight_row_shuffle_params; + if (stride_width == 1) { + one_row_shuffle_params = ShuffleParams(30, 1, 1, 1); + two_row_shuffle_params = ShuffleParams(22, 2, 1, 1); + four_row_shuffle_params = ShuffleParams(14, 4, 1, 1); + eight_row_shuffle_params = ShuffleParams(8, 8, 1, 1); + } else { + one_row_shuffle_params = ShuffleParams(14, 1, 2, 2); + two_row_shuffle_params = ShuffleParams(8, 2, 2, 2); + four_row_shuffle_params = ShuffleParams(4, 4, 2, 2); + eight_row_shuffle_params = ShuffleParams(2, 8, 2, 2); + } - const int input_row_size = input_depth * (input_width + 2 * pad_width); - const int output_row_size = output_depth * output_width; - const int input_batch_size = input_row_size * (input_height + 2 * pad_height); - const int output_batch_size = output_depth * output_width * output_height; - - using conv_row_func_t = decltype(&ConvRow3x3FilterDepth8<1, 1, 1>::Run); - conv_row_func_t conv_1_output_row = ConvRow3x3FilterDepth8<1, 1, 1>::Run; - conv_row_func_t conv_2_output_rows = ConvRow3x3FilterDepth8<2, 1, 1>::Run; - conv_row_func_t conv_4_output_rows = ConvRow3x3FilterDepth8<4, 1, 1>::Run; - conv_row_func_t conv_8_output_rows = ConvRow3x3FilterDepth8<8, 1, 1>::Run; - + using conv_multirow_func_t = decltype(&DepthwiseConvMultiRow<1, 1>::Run); + conv_multirow_func_t conv_multirow_func = DepthwiseConvMultiRow<1, 1>::Run; if (stride_width == 2) { - conv_1_output_row = ConvRow3x3FilterDepth8<1, 2, 2>::Run; - conv_2_output_rows = ConvRow3x3FilterDepth8<2, 2, 2>::Run; - conv_4_output_rows = ConvRow3x3FilterDepth8<4, 2, 2>::Run; - conv_8_output_rows = ConvRow3x3FilterDepth8<8, 2, 2>::Run; + conv_multirow_func = DepthwiseConvMultiRow<2, 2>::Run; } // Allocate maximum memory needed for shuffled input. // TODO(mariewhite): The size of this workspace is small enough to be // allocated on the stack. Eventually we will want to move it to the heap - // and have it allocated outside of this function, like the im2col_array used - // in gemmlowp. -#define DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE 10 * 10 * 64 + // and have it allocated outside of this function, like the im2col_array + // used in gemmlowp. uint8 shuffle_workspace[DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE]; - // Make sure the kernels using this buffer will not run out of bounds. - static_assert(ConvRow3x3FilterDepth8<8, 1, 1>::ShuffleWorkspaceSize() <= - DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE, - "Shuffle workspace size is too small."); - static_assert(ConvRow3x3FilterDepth8<4, 2, 2>::ShuffleWorkspaceSize() <= - DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE, - "Shuffle workspace size is too small."); - -#undef DEPTHWISECONV_SHUFFLE_WORKSPACE_SIZE - - for (int b = 0; b < batches; ++b) { + for (int32 b = 0; b < batches; ++b) { const uint8* input_ptr = input_data + b * input_batch_size; uint8* output_ptr = output_data + b * output_batch_size; - int out_y = 0; + int32 out_x = 0; + int32 out_y = 0; + int32 end_x = params.output_width; + int32 end_y = params.output_height; + + if (pad_width == 1 && pad_height == 1) { + DepthwiseConvHandlePadding(input_ptr, filter_data, bias_data, output_ptr, + params); + + // Update extents now that the edges have been handled. + out_x = 1; + end_x = params.output_width - 1; + out_y = 1; + end_y = params.output_height - 1; + const int in_x = (out_x * stride_width) - pad_width; + const int in_y = (out_y * stride_height) - pad_height; + input_ptr += in_y * params.input_row_size + in_x * params.input_depth; + output_ptr += out_y * params.output_row_size + + out_x * params.output_depth; + } + + // Shuffling shapes that maximize width over the shuffle workspace size + // perform better since the inputs are closer together, minimizing + // shuffling time. + // + // If the input shape has width large enough for the 2 row kernels, + // we prefer to use this. The innermost loop of the kernels handle + // 2 height x 2 width so this is the fastest path. + // + // If the input shape has smaller width but larger height, shuffling is + // still useful and can benefit from kernels 4 row and 8 row kernels. // Handle 8 rows at a time. - for (; out_y <= output_height - 8; out_y += 8) { - conv_8_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 8 * stride_height * input_row_size; - output_ptr += 8 * output_row_size; + if (params.input_width < four_row_shuffle_params.input_width) { + for (; out_y <= end_y - 8; out_y += 8) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, eight_row_shuffle_params, + shuffle_workspace); + input_ptr += 8 * stride_height * params.input_row_size; + output_ptr += 8 * params.output_row_size; + } } // Handle 4 rows at a time. - for (; out_y <= output_height - 4; out_y += 4) { - conv_4_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 4 * stride_height * input_row_size; - output_ptr += 4 * output_row_size; + if (params.input_width < two_row_shuffle_params.input_width) { + for (; out_y <= end_y - 4; out_y += 4) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, four_row_shuffle_params, + shuffle_workspace); + input_ptr += 4 * stride_height * params.input_row_size; + output_ptr += 4 * params.output_row_size; + } } // Handle 2 rows at a time. - for (; out_y <= output_height - 2; out_y += 2) { - conv_2_output_rows(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, - filter_data, filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += 2 * stride_height * input_row_size; - output_ptr += 2 * output_row_size; + for (; out_y <= end_y - 2; out_y += 2) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, two_row_shuffle_params, + shuffle_workspace); + input_ptr += 2 * stride_height * params.input_row_size; + output_ptr += 2 * params.output_row_size; } // Handle one row at a time. - for (; out_y < output_height; out_y++) { - conv_1_output_row(input_ptr, 0, out_y, input_depth, input_width, - input_height, input_row_size, input_offset, filter_data, - filter_offset, bias_data, output_offset, - output_multiplier, output_shift, output_activation_min, - output_activation_max, output_ptr, output_depth, - output_width, shuffle_workspace); - - input_ptr += stride_height * input_row_size; - output_ptr += output_row_size; + for (; out_y < end_y; out_y++) { + conv_multirow_func(input_ptr, out_x, end_x, filter_data, bias_data, + output_ptr, params, one_row_shuffle_params, + shuffle_workspace); + input_ptr += stride_height * params.input_row_size; + output_ptr += params.output_row_size; } } } +// clang-format on #endif // __aarch64__ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index 08f7cfa5a5f9453cd187164078898e754126da52..38ad32c734a2286c7d23162810625169a4d8df43 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -352,6 +352,30 @@ void NeonSub1Vector(const float* vector, int v_size, float* result) { } } +bool NeonIsZeroVector(const float* vector, int v_size) { + // If v_size is not divisible by kFloatWeightsPerNeonLane, we cannot + // use the main vectorized loop, and we need to process sequentially. + // postamble_start shows the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + const float32x4_t zero_x4_float = vmovq_n_f32(0.0f); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + const float32x4_t i_x4_float = vld1q_f32(vector + v); + uint32x4_t cmp_result = vceqq_f32(i_x4_float, zero_x4_float); + if (vgetq_lane_u32(cmp_result, 0) == 0) return false; + if (vgetq_lane_u32(cmp_result, 1) == 0) return false; + if (vgetq_lane_u32(cmp_result, 2) == 0) return false; + if (vgetq_lane_u32(cmp_result, 3) == 0) return false; + } + + // Postamble loop + for (int v = postamble_start; v < v_size; ++v) { + if (vector[v] != 0.0) return false; + } + return true; +} + void NeonClipVector(const float* vector, int v_size, float abs_limit, float* result) { // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 9e60d0657b49ed5ee1f2e999acb86ebee0eab972..7a5a8fc54123946229963abd1720030d0bb358bf 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -100,6 +100,11 @@ void ZeroVector(float* vector, int v_size) { float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } +// Check if all entries of a vector are zero. +bool IsZeroVector(const float* vector, int v_size) { + return NEON_OR_PORTABLE(IsZeroVector, vector, v_size); +} + void ClipVector(const float* vector, int v_size, float abs_limit, float* result) { NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index b33e4b5017d6ace16799fdea9b4d08dc6c110660..be4825f2507ccb583298af6cd824dd6da1fc47c7 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -51,6 +51,13 @@ using reference_ops::LessEqual; using reference_ops::RankOneSelect; using reference_ops::Select; +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). +static constexpr int kReverseShift = -1; + // Make a local VectorMap typedef allowing to map a float array // as a Eigen vector expression. The std::conditional here is to // construct the suitable Eigen type for the constness of the @@ -140,6 +147,45 @@ MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, return MatrixMap(data, rows, cols); } +// This is like the template-parameter version, except that the power-of-two is +// passed as a function parameter. The template version is to be preferred, +// since some target hardware optimizations depend on the range of the exponent. +template +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = + gemmlowp::Dup(std::numeric_limits::min()); + const IntegerType max = + gemmlowp::Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// This is like the template-parameter version, except that the power-of-two is +// passed as a function parameter. See raw-integer version for further comments. +template +gemmlowp::FixedPoint +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint a, int exponent) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING. // @@ -1036,10 +1082,10 @@ struct GemmlowpOutputPipeline { gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; - static Pipeline Make(const int32* bias_data, int output_rows, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max) { + static Pipeline MakeExp(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_left_shift, int32 output_activation_min, + int32 output_activation_max) { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; @@ -1047,7 +1093,7 @@ struct GemmlowpOutputPipeline { quantize_down_stage; quantize_down_stage.result_offset_after_shift = output_offset; quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; - quantize_down_stage.result_shift = output_shift; + quantize_down_stage.result_shift = -output_left_shift; gemmlowp::OutputStageClamp clamp_stage; clamp_stage.min = output_activation_min; clamp_stage.max = output_activation_max; @@ -1100,8 +1146,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, batches, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, batches, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -1730,6 +1776,100 @@ inline void ExtractPatchIntoBufferColumn( } } +template +void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, + const Dims<4>& filter_dims, int stride_width, + int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + const Dims<4>& output_dims, uint8 byte_zero, + T* im2col_data) { + // For dilated convolution, the input pixels are not contiguous therefore we + // can't use the same opitimizations as Im2Col(). Though note this code would + // work fine for the non-dilated case too (though likely a bit slower). + gemmlowp::ScopedProfilingLabel label("DilatedIm2col"); + TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK(im2col_data); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + MatchingArraySize(output_dims, 0, filter_dims, 3); + + // Construct the MxN sized im2col matrix. + // The rows M, are sub-ordered B x H x W + Dims<4> row_dims; + row_dims.sizes[0] = output_width; + row_dims.sizes[1] = output_height; + row_dims.sizes[2] = batches; + row_dims.sizes[3] = 1; + ComputeStrides(&row_dims); + + // The columns, N, are sub-ordered Kh x Kw x Din + Dims<4> col_dims; + col_dims.sizes[0] = input_depth; + col_dims.sizes[1] = filter_width; + col_dims.sizes[2] = filter_height; + col_dims.sizes[3] = 1; + ComputeStrides(&col_dims); + + // Use dimensions M and N to construct dims for indexing directly into im2col + Dims<4> im2col_dims; + im2col_dims.sizes[0] = col_dims.strides[3]; + im2col_dims.sizes[1] = row_dims.strides[3]; + im2col_dims.sizes[2] = 1; + im2col_dims.sizes[3] = 1; + ComputeStrides(&im2col_dims); + + // Loop through the output rows (B x H x W) + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + // Each row is an output pixel. Arrange the input data into this row in + // an order we can conveniently multiply with the filter data. + int row_offset = Offset(row_dims, out_x, out_y, batch, 0); + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Loop through all the pixels of the filter (Kh x Kw) + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int in_y = in_y_origin + dilation_height_factor * filter_y; + if ((in_y >= 0) && (in_y < input_height)) { + // Filter row is within the input data. + // Loop through all the filter pixels in this row. + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + int col_offset = Offset(col_dims, 0, filter_x, filter_y, 0); + T* dst = im2col_data + + Offset(im2col_dims, col_offset, row_offset, 0, 0); + if ((in_x >= 0) && (in_x < input_width)) { + // Filter pixel is within the input, copy the data. + T const* src = + input_data + Offset(input_dims, 0, in_x, in_y, batch); + memcpy(dst, src, input_depth * sizeof(T)); + } else { + // Filter pixel is outside the input, zero it out. + memset(dst, byte_zero, input_depth * sizeof(T)); + } + } + } else { + // Filter row is outside the input, zero out the entire im2col row. + int col_offset = Offset(col_dims, 0, 0, filter_y, 0); + T* dst = + im2col_data + Offset(im2col_dims, col_offset, row_offset, 0, 0); + memset(dst, byte_zero, filter_width * input_depth * sizeof(T)); + } + } + } + } + } +} + template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kheight, @@ -1770,74 +1910,6 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, kwidth, byte_zero, output_data, output_dims); } -inline void DilatedConv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, - int dilation_width_factor, int dilation_height_factor, - int pad_width, int pad_height, - float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { - gemmlowp::ScopedProfilingLabel label("DilatedConv"); - // This is a copy of the reference Conv implementation. We do not currently - // have an optimized path for dilation. - (void)im2col_data; // only used in optimized code. - (void)im2col_dims; // only used in optimized code. - const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); - const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); - if (bias_data) { - TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); - } - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); - const int filter_height = ArraySize(filter_dims, 2); - const int filter_width = ArraySize(filter_dims, 1); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - for (int batch = 0; batch < batches; ++batch) { - for (int out_y = 0; out_y < output_height; ++out_y) { - for (int out_x = 0; out_x < output_width; ++out_x) { - for (int out_channel = 0; out_channel < output_depth; ++out_channel) { - const int in_x_origin = (out_x * stride_width) - pad_width; - const int in_y_origin = (out_y * stride_height) - pad_height; - float total = 0.f; - for (int filter_y = 0; filter_y < filter_height; ++filter_y) { - for (int filter_x = 0; filter_x < filter_width; ++filter_x) { - for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + dilation_width_factor * filter_x; - const int in_y = - in_y_origin + dilation_height_factor * filter_y; - // If the location is outside the bounds of the input image, - // use zero as a default value. - if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && - (in_y < input_height)) { - float input_value = input_data[Offset(input_dims, in_channel, - in_x, in_y, batch)]; - float filter_value = - filter_data[Offset(filter_dims, in_channel, filter_x, - filter_y, out_channel)]; - total += (input_value * filter_value); - } - } - } - } - float bias_value = 0.0f; - if (bias_data) { - bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; - } - output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = - ActivationFunctionWithMinMax(total + bias_value, - output_activation_min, - output_activation_max); - } - } - } - } -} - inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, @@ -1846,29 +1918,32 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { - if ((dilation_width_factor != 1) || (dilation_height_factor != 1)) { - return DilatedConv(input_data, input_dims, filter_data, filter_dims, - bias_data, bias_dims, stride_width, stride_height, - dilation_width_factor, dilation_height_factor, pad_width, - pad_height, output_activation_min, output_activation_max, - output_data, output_dims, im2col_data, im2col_dims); - } - (void)im2col_data; (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); + // A float set to 0x00000000h == 0.0f + const uint8 float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const Dims<4>* gemm_input_dims = nullptr; const int filter_width = ArraySize(filter_dims, 1); const int filter_height = ArraySize(filter_dims, 2); + const bool need_dilated_im2col = + dilation_width_factor != 1 || dilation_height_factor != 1; const bool need_im2col = stride_width != 1 || stride_height != 1 || filter_width != 1 || filter_height != 1; - if (need_im2col) { + if (need_dilated_im2col) { + DilatedIm2col(input_data, input_dims, filter_dims, stride_width, + stride_height, dilation_width_factor, dilation_height_factor, + pad_width, pad_height, output_dims, float_zero_byte, + im2col_data); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else if (need_im2col) { TFLITE_DCHECK(im2col_data); Im2col(input_data, input_dims, stride_width, stride_height, pad_width, - pad_height, filter_height, filter_width, 0, im2col_data, - im2col_dims); + pad_height, filter_height, filter_width, float_zero_byte, + im2col_data, im2col_dims); gemm_input_data = im2col_data; gemm_input_dims = &im2col_dims; } else { @@ -1979,11 +2054,23 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, } const int gemm_input_rows = gemm_input_dims->sizes[0]; - const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0); + // Using FlatSizeSkipDim causes segfault in some contexts (see b/79927784). + // The root cause has not yet been identified though. Same applies below for + // the other calls commented out. This is a partial rollback of cl/196819423. + // const int gemm_input_cols = FlatSizeSkipDim(*gemm_input_dims, 0); + const int gemm_input_cols = gemm_input_dims->sizes[1] * + gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; const int filter_rows = filter_dims.sizes[3]; - const int filter_cols = FlatSizeSkipDim(filter_dims, 3); + // See b/79927784. + // const int filter_cols = FlatSizeSkipDim(filter_dims, 3); + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; const int output_rows = output_dims.sizes[0]; - const int output_cols = FlatSizeSkipDim(output_dims, 0); + // See b/79927784. + // const int output_cols = FlatSizeSkipDim(output_dims, 0); + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; TFLITE_DCHECK_EQ(output_rows, filter_rows); TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); @@ -1997,8 +2084,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, gemm_input_data, gemm_input_rows, gemm_input_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2155,8 +2242,8 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, input_data, filter_cols, output_cols, filter_cols); gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols, output_rows); - const auto& output_pipeline = GemmlowpOutputPipeline::Make( - bias_data, output_rows, output_offset, output_multiplier, output_shift, + const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( + bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); gemmlowp::GemmWithOutputPipeline( @@ -2300,8 +2387,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -2343,6 +2431,7 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, @@ -2353,24 +2442,27 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - TFLITE_DCHECK_EQ(outer_size, 1); - int32 square_l2_norm = 0; - for (int i = 0; i < depth; i++) { - int32 diff = input_data[i] - input_zero_point; - square_l2_norm += diff * diff; - } - int32 inv_l2norm_multiplier; - int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); - - for (int i = 0; i < depth; i++) { - int32 diff = input_data[i] - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); - int32 unclamped_output_val = 128 + rescaled_diff; - int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[i] = static_cast(output_val); + for (int i = 0; i < outer_size; ++i) { + int32 square_l2_norm = 0; + for (int c = 0; c < depth; c++) { + int32 diff = input_data[c] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int c = 0; c < depth; c++) { + int32 diff = *input_data - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + *output_data = static_cast(output_val); + ++input_data; + ++output_data; + } } } @@ -2506,14 +2598,19 @@ inline void AddElementwise(int size, int left_shift, const uint8* input1_data, const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; - const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + - output_offset; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + + output_offset; const int32 clamped_output = std::min( output_activation_max, std::max(output_activation_min, raw_output)); output_data[i] = static_cast(clamped_output); @@ -2732,15 +2829,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -3081,9 +3180,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -3265,15 +3364,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -4556,6 +4657,119 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, } } +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + // assert(__builtin_clz(0u) >= std::numeric_limits::digits - 1); + // assert(__builtin_clz(0u) <= std::numeric_limits::digits); + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. + static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. + FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = __builtin_clz(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = __builtin_clz(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +// Minimum output bits to accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + // Currently just a copy of the reference code. inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, int32 input_multiplier, int32 input_left_shift, @@ -4601,13 +4815,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } - // TODO(b/77858996): Implement fixed-point log(). - // Not a fully-quantized implementation: floating-point log(). - const float float_log_sum_of_exps = - std::log(static_cast(sum_of_exps.raw()) / - (1 << (31 - kAccumulationIntegerBits))); - const int32 fixed_log_sum_of_exps = static_cast(TfLiteRound( - float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits)))); + const int32 fixed_log_sum_of_exps = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps) + .raw(); // rescaled_diff_min is smallest representable in // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the @@ -4618,9 +4829,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. - MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { int32 input_diff = static_cast(block_input_data[c]) - max_in_row; @@ -5513,6 +5724,46 @@ inline void ResizeBilinearGeneric(const float* input_data, } } +template +inline void ResizeBilinearGenericSmallChannel( + const T* input_data, const Dims<4>& input_dims, T* output_data, + const Dims<4>& output_dims, int32 batches, int32 input_height, + int32 input_width, int32 depth, int32 output_height, int32 output_width, + float height_scale, float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(T)); + + T* output_ptr = &output_data[0]; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + + int32 input_offset[4] = { + Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b), + Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)}; + float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)), + (1 - (input_y - y0)) * (input_x - x0), + (input_y - y0) * (1 - (input_x - x0)), + (input_y - y0) * (input_x - x0)}; + + for (int d = 0; d < depth; d++) { + const T* input_ptr = &input_data[d]; + *output_ptr++ = static_cast(input_ptr[input_offset[0]] * scale[0] + + input_ptr[input_offset[1]] * scale[1] + + input_ptr[input_offset[2]] * scale[2] + + input_ptr[input_offset[3]] * scale[3]); + } + } + } + } +} + inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, @@ -5553,6 +5804,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8 +// or int16 arithmetic. +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims, bool align_corners) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + float height_scale = + (align_corners && output_height > 1) + ? (static_cast(input_height - 1) / (output_height - 1)) + : (static_cast(input_height) / output_height); + + float width_scale = + (align_corners && output_width > 1) + ? (static_cast(input_width - 1) / (output_width - 1)) + : (static_cast(input_width) / output_width); + + ResizeBilinearGenericSmallChannel( + input_data, input_dims, output_data, output_dims, batches, input_height, + input_width, depth, output_height, output_width, height_scale, + width_scale); +} + // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, @@ -5562,6 +5848,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, output_data, output_dims, /*align_corners=*/false); } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -6080,8 +6375,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, // To optimize, start by using the conv code with transposed weights for the // case of stride_height = stride_width = 1. const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); const int filter_height = ArraySize(filter_dims, 2); @@ -6128,8 +6423,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, float input_value = input_data[Offset(input_dims, in_channel, in_x, in_y, batch)]; float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] += input_value * filter_value; } diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index d570dadd86b4dc7c3abe341a4955320367330b9c..f14667090f5c3867c7992211272063239f3b92aa 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -127,6 +127,10 @@ void PortableZeroVector(float* vector, int v_size); // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); +// Check if all entries of a vector are zero. +bool PortableIsZeroVector(const float* vector, int v_size); +bool NeonIsZeroVector(const float* vector, int v_size); + // Symmetric quantizer. void PortableSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index b0951aac8cbb98a181d9dcaef88770fadfc74f62..57ee859115cddbcbccae24ff639e848340d8e2ee 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -48,15 +48,15 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, TFLITE_CHECK_GE(*left_shift, 0); } -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift) { +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift) { TFLITE_CHECK_LT(double_multiplier, 1.); TFLITE_CHECK_GT(double_multiplier, 0.); int shift; QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); TFLITE_CHECK_LE(shift, 0); - *right_shift = -shift; + *left_shift = shift; } void PreprocessSoftmaxScaling(double beta, double input_scale, @@ -78,20 +78,21 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, quantized_multiplier, left_shift); } -void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift) { +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift) { PreprocessSoftmaxScaling(beta, input_scale, input_integer_bits, quantized_multiplier, left_shift); // Also calculate what amounts to the inverse scaling factor for the input. const double real_reverse_scaling_divisor = (1 << (31 - *left_shift)) / static_cast(*quantized_multiplier); - tflite::QuantizeMultiplierSmallerThanOne(real_reverse_scaling_divisor, - reverse_scaling_divisor, - reverse_scaling_right_shift); + tflite::QuantizeMultiplierSmallerThanOneExp(real_reverse_scaling_divisor, + reverse_scaling_divisor, + reverse_scaling_left_shift); } int CalculateInputRadius(int input_integer_bits, int input_left_shift) { diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index 4a217515f142b2451ebd61e423871b95cdc09748..182ee782c76fcccedc99327d47805b49bfb8580d 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -167,9 +167,9 @@ IntOut SafeCast(FloatIn x) { // this is intended as a RIGHT-shift. // // Restricted to the case where the multiplier < 1 (and non-negative). -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift); +void QuantizeMultiplierSmallerThanOneExp(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); // Decompose a double multiplier into a Q0.31 int32 representation of its // significand, and shift representation of its exponent. @@ -197,11 +197,12 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, int input_integer_bits, int32_t* quantized_multiplier, int* left_shift); // Like PreprocessSoftmaxScaling, but inverse scaling factors also calculated. -void PreprocessLogSoftmaxScaling(double beta, double input_scale, - int input_integer_bits, - int32_t* quantized_multiplier, int* left_shift, - int32_t* reverse_scaling_divisor, - int* reverse_scaling_right_shift); +void PreprocessLogSoftmaxScalingExp(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, + int* left_shift, + int32_t* reverse_scaling_divisor, + int* reverse_scaling_left_shift); // Calculate the largest input that will result in a within-bounds intermediate // result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, // it must not overflow before we reduce the value by multiplication by the diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index 2d74b3d3849812a2dc95fabcd680aa280c99ca55..94773b47d3817d7ed7240f74545ad04e7fa4bd52 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -196,21 +196,21 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); } -TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) { auto quantize = [](double d) { int32_t q; int s; - QuantizeMultiplierSmallerThanOne(d, &q, &s); + QuantizeMultiplierSmallerThanOneExp(d, &q, &s); return std::pair{q, s}; }; EXPECT_DEATH(quantize(-0.1), ""); EXPECT_DEATH(quantize(0.0), ""); - EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, -1)); // Around 0.5 we can see the change in exponent and how we try hard to // void hitting max int32. - EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, -1)); EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h index e9b6baeaee87d22aef238410bc9f447509a81c47..d57739279f44ad2a9fff2bd4dca21047b8147f2a 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -76,8 +76,8 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, + -output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc index 2607adc0c18aeaa8dc2061e0e95a307205700a08..f8c6f341f7e61529bbbac592f9caf115f6121e0c 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -29,9 +29,18 @@ float PortableClip(float f, float abs_limit) { return result; } +bool PortableIsZeroVector(const float* vector, int v_size) { + for (int i = 0; i < v_size; ++i) { + if (*vector++ != 0.0f) return false; + } + return true; +} + void PortableSymmetricQuantizeFloats(const float* values, const int size, - int8_t* quantized_values, float* min, - float* max, float* scaling_factor) { + int8_t* quantized_values, + float* __restrict__ min, + float* __restrict__ max, + float* __restrict__ scaling_factor) { auto minmax = std::minmax_element(values, values + size); *min = *minmax.first; *max = *minmax.second; @@ -71,13 +80,14 @@ void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, void PortableMatrixBatchVectorMultiplyAccumulate( const int8_t* __restrict__ matrix, const int m_rows, const int m_cols, - const int8_t* __restrict__ vectors, const float* scaling_factors, - int n_batch, float* __restrict__ result, int result_stride) { + const int8_t* __restrict__ vectors, + const float* __restrict__ scaling_factors, int n_batch, + float* __restrict__ result, int result_stride) { int batch, row, col; for (batch = 0; batch < n_batch; ++batch, vectors += m_cols) { const float batch_scaling_factor_inv = 1.0 / scaling_factors[batch]; // Get the address of the first row. - int8_t* row_ptr = (int8_t*)matrix; // NOLINT + const int8_t* row_ptr = matrix; for (row = 0; row < m_rows; ++row, result += result_stride) { // Initialize the dot product sum for the row to 0. int32_t dotprod = 0; diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index 1757a9f5e5299401fb9fef1d38870bd4e63ba3c1..d2e1fecd25cf3d11d3daffcc566dc1d5df97128c 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -25,6 +25,8 @@ namespace tensor_utils { // Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); +bool PortableIsZeroVector(const float* vector, int v_size); + void PortableSymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor); @@ -112,6 +114,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector, float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } +bool IsZeroVector(const float* vector, int v_size) { + return PortableIsZeroVector(vector, v_size); +} + void SymmetricQuantizeFloats(const float* values, const int size, int8_t* quantized_values, float* min, float* max, float* scaling_factor) { diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 15fac237edae79ea8d778b6e28694acb8de48668..07c215ebcb6f65b33df7086a2ffc1c3d1f04c605 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -33,8 +33,131 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/internal/types.h" namespace tflite { + +// TODO(b/77858996): Add these to gemmlowp. +template +IntegerType SaturatingAddNonGemmlowp(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int32_t SaturatingAddNonGemmlowp(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t sum = a64 + b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + sum))); +} + +template +gemmlowp::FixedPoint SaturatingAddNonGemmlowp( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingAddNonGemmlowp(a.raw(), b.raw())); +} + +template +IntegerType SaturatingSub(IntegerType a, IntegerType b) { + static_assert(std::is_same::value, "unimplemented"); + return a; +} + +template <> +inline std::int16_t SaturatingSub(std::int16_t a, std::int16_t b) { + std::int32_t a32 = a; + std::int32_t b32 = b; + std::int32_t diff = a32 - b32; + return static_cast(std::min(32767, std::max(-32768, diff))); +} + +template <> +inline std::int32_t SaturatingSub(std::int32_t a, std::int32_t b) { + std::int64_t a64 = a; + std::int64_t b64 = b; + std::int64_t diff = a64 - b64; + return static_cast(std::min( + static_cast(std::numeric_limits::max()), + std::max( + static_cast(std::numeric_limits::min()), + diff))); +} + +template +gemmlowp::FixedPoint SaturatingSub( + gemmlowp::FixedPoint a, + gemmlowp::FixedPoint b) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingSub(a.raw(), b.raw())); +} +// End section to be moved to gemmlowp. + namespace reference_ops { +// TODO(b/80247582) Remove this constant. +// This will be phased out as the shifts are revised with more thought. Use of a +// constant enables us to track progress on this work. +// +// Used mainly to convert from old-style shifts (right) to new-style (left). +static constexpr int kReverseShift = -1; + +template +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned::value, + "Only unsigned integer types handled."); + if (integer_input == 0) { + return std::numeric_limits::digits; + } + const T one_in_leading_positive = static_cast(1) + << (std::numeric_limits::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +template +IntegerType SaturatingRoundingMultiplyByPOTParam(IntegerType x, int exponent) { + if (exponent == 0) { + return x; + } + using ScalarIntegerType = + typename gemmlowp::FixedPointRawTypeTraits::ScalarRawType; + const IntegerType min = + gemmlowp::Dup(std::numeric_limits::min()); + const IntegerType max = + gemmlowp::Dup(std::numeric_limits::max()); + const int ScalarIntegerTypeBits = 8 * sizeof(ScalarIntegerType); + + const std::int32_t threshold = + ((1 << (ScalarIntegerTypeBits - 1 - exponent)) - 1); + const IntegerType positive_mask = + gemmlowp::MaskIfGreaterThan(x, gemmlowp::Dup(threshold)); + const IntegerType negative_mask = + gemmlowp::MaskIfLessThan(x, gemmlowp::Dup(-threshold)); + + IntegerType result = gemmlowp::ShiftLeft(x, exponent); + result = gemmlowp::SelectUsingMask(positive_mask, max, result); + result = gemmlowp::SelectUsingMask(negative_mask, min, result); + return result; +} + +// If we want to leave IntegerBits fixed, then multiplication +// by a power of two has to be saturating/rounding, not exact anymore. +template +gemmlowp::FixedPoint +SaturatingRoundingMultiplyByPOTParam( + gemmlowp::FixedPoint a, int exponent) { + return gemmlowp::FixedPoint::FromRaw( + SaturatingRoundingMultiplyByPOTParam(a.raw(), exponent)); +} + // DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE // BROADCASTING. // @@ -291,8 +414,8 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne( - acc, output_multiplier, output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -515,8 +638,8 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, if (bias_data) { acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; } - acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, - output_shift); + acc = MultiplyByQuantizedMultiplierSmallerThanOneExp( + acc, output_multiplier, kReverseShift * output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); @@ -845,8 +968,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, } } -inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, - int* output_shift) { +inline void GetInvSqrtQuantizedMultiplierExp(int32 input, + int32* output_inv_sqrt, + int* output_shift) { *output_shift = 11; while (input >= (1 << 29)) { input /= 4; @@ -888,6 +1012,7 @@ inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, *output_inv_sqrt <<= -*output_shift; *output_shift = 0; } + *output_shift *= kReverseShift; } inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, @@ -895,25 +1020,28 @@ inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, const Dims<4>& output_dims) { const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims); - TFLITE_DCHECK_EQ(outer_size, 1); - int32 square_l2_norm = 0; - for (int i = 0; i < depth; i++) { - int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; - square_l2_norm += diff * diff; - } - int32 inv_l2norm_multiplier; - int inv_l2norm_shift; - GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, - &inv_l2norm_shift); - - for (int i = 0; i < depth; i++) { - int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; - int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( - 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); - int32 unclamped_output_val = 128 + rescaled_diff; - int32 output_val = std::min(255, std::max(0, unclamped_output_val)); - output_data[Offset(output_dims, i, 0, 0, 0)] = - static_cast(output_val); + for (int i = 0; i < outer_size; ++i) { + int32 square_l2_norm = 0; + for (int c = 0; c < depth; c++) { + int32 diff = + input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplierExp(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int c = 0; c < depth; c++) { + int32 diff = + input_data[Offset(input_dims, c, i, 0, 0)] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOneExp( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[Offset(output_dims, c, i, 0, 0)] = + static_cast(output_val); + } } } @@ -979,15 +1107,17 @@ inline void Add(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1133,15 +1263,17 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1186,15 +1318,17 @@ inline void BroadcastAddFivefold( const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sum = scaled_input1_val + scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sum, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sum, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -1374,9 +1508,9 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, const int32 input2_val = input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; const int32 unclamped_result = - output_offset + - MultiplyByQuantizedMultiplierSmallerThanOne( - input1_val * input2_val, output_multiplier, output_shift); + output_offset + MultiplyByQuantizedMultiplierSmallerThanOneExp( + input1_val * input2_val, output_multiplier, + kReverseShift * output_shift); const int32 clamped_output = std::min(output_activation_max, std::max(output_activation_min, unclamped_result)); @@ -1590,15 +1724,17 @@ inline void BroadcastSub(int left_shift, const uint8* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); const int32 raw_sub = scaled_input1_val - scaled_input2_val; const int32 raw_output = - MultiplyByQuantizedMultiplierSmallerThanOne( - raw_sub, output_multiplier, output_shift) + + MultiplyByQuantizedMultiplierSmallerThanOneExp( + raw_sub, output_multiplier, kReverseShift * output_shift) + output_offset; const int32 clamped_output = std::min(output_activation_max, @@ -2639,6 +2775,121 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, } } +// Although currently the name of this function says that it cannot handle +// values less than 1, in practice it can handle as low as 1/x_max, where +// x_max is the largest representable input. In other words, the output range +// is symmetric. +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1_impl( + gemmlowp::FixedPoint input_val) { + using FixedPoint0 = gemmlowp::FixedPoint; + // The reason for accumulating the result with an extra bit of headroom is + // that z_pow_2_adj * log_2 might be saturated, and adding num_scaled * + // recip_denom will otherwise introduce an error. + static constexpr int kAccumIntegerBits = OutputIntegerBits + 1; + using FixedPointAccum = gemmlowp::FixedPoint; + + const FixedPoint0 log_2 = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1488522236, std::log(2.0)); + const FixedPoint0 sqrt_sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1805811301, std::sqrt(std::sqrt(0.5))); + const FixedPoint0 sqrt_half = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1518500250, std::sqrt(0.5)); + const FixedPoint0 one_quarter = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(FixedPoint0, 536870912, 1.0 / 4.0); + + const FixedPoint0 alpha_n = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 117049297, 11.0 / 240.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_d = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 127690142, 1.0 / 20.0 * std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_i = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 1057819769, + 2.0 / std::sqrt(std::sqrt(2.0)) - std::sqrt(std::sqrt(2.0))); + const FixedPoint0 alpha_f = GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT( + FixedPoint0, 638450708, 1.0 / 4.0 * std::sqrt(std::sqrt(2.0))); + + const FixedPointAccum shifted_quarter = + gemmlowp::Rescale(one_quarter); + + // Reinterpret the input value as Q0.31, because we will figure out the + // required shift "ourselves" instead of using, say, Rescale. + FixedPoint0 z_a = FixedPoint0::FromRaw(input_val.raw()); + // z_a_pow_2 = input_integer_bits - z_a_headroom; + int z_a_headroom_plus_1 = CountLeadingZeros(static_cast(z_a.raw())); + FixedPoint0 r_a_tmp = + SaturatingRoundingMultiplyByPOTParam(z_a, (z_a_headroom_plus_1 - 1)); + const int32 r_a_raw = + SaturatingRoundingMultiplyByPOTParam((r_a_tmp * sqrt_half).raw(), 1); + // z_pow_2_adj = max(z_pow_2_a - 0.75, z_pow_2_b - 0.25); + // z_pow_2_adj = max(InputIntegerBits - z_a_headroom_plus_1 + 0.25, + // InputIntegerBits - z_b_headroom - 0.25); + const FixedPointAccum z_a_pow_2_adj = SaturatingAddNonGemmlowp( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_a_headroom_plus_1, 31 - kAccumIntegerBits)), + shifted_quarter); + + // z_b is treated like z_a, but premultiplying by sqrt(0.5). + FixedPoint0 z_b = z_a * sqrt_half; + int z_b_headroom = CountLeadingZeros(static_cast(z_b.raw())) - 1; + const int32 r_b_raw = + SaturatingRoundingMultiplyByPOTParam(z_a.raw(), z_b_headroom); + const FixedPointAccum z_b_pow_2_adj = SaturatingSub( + FixedPointAccum::FromRaw(SaturatingRoundingMultiplyByPOTParam( + InputIntegerBits - z_b_headroom, 31 - kAccumIntegerBits)), + shifted_quarter); + + const FixedPoint0 r = FixedPoint0::FromRaw(std::min(r_a_raw, r_b_raw)); + const FixedPointAccum z_pow_2_adj = FixedPointAccum::FromRaw( + std::max(z_a_pow_2_adj.raw(), z_b_pow_2_adj.raw())); + + const FixedPoint0 p = gemmlowp::RoundingHalfSum(r, sqrt_sqrt_half); + FixedPoint0 q = r - sqrt_sqrt_half; + q = q + q; + + const FixedPoint0 common_sq = q * q; + const FixedPoint0 num = q * r + q * common_sq * alpha_n; + const FixedPoint0 denom_minus_one_0 = + p * (alpha_i + q + alpha_d * common_sq) + alpha_f * q; + const FixedPoint0 recip_denom = + one_over_one_plus_x_for_x_in_0_1(denom_minus_one_0); + + const FixedPointAccum num_scaled = gemmlowp::Rescale(num); + return gemmlowp::Rescale(z_pow_2_adj * log_2 + + num_scaled * recip_denom); +} + +// Minimum output bits to accommodate log of maximum input range. It actually +// does not matter if one considers, say, [-64,64] or [-64,64). +// +// For example, run this through Octave: +// [0:127; ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2)); ... +// ceil(log(abs( log(2.^(0:127))+1 ))/log(2))] +constexpr int min_log_x_output_bits(int input_bits) { + return input_bits > 90 + ? 7 + : input_bits > 44 + ? 6 + : input_bits > 21 + ? 5 + : input_bits > 10 + ? 4 + : input_bits > 4 ? 3 : input_bits > 1 ? 2 : 1; +} + +template +inline gemmlowp::FixedPoint +log_x_for_x_greater_than_or_equal_to_1( + gemmlowp::FixedPoint input_val) { + static_assert( + OutputIntegerBits >= min_log_x_output_bits(InputIntegerBits), + "Output integer bits must be sufficent to accommodate logs of inputs."); + return log_x_for_x_greater_than_or_equal_to_1_impl( + input_val); +} + inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, int32 input_multiplier, int32 input_left_shift, int32 reverse_scaling_divisor, @@ -2681,13 +2932,10 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } } - // TODO(b/77858996): Implement fixed-point log(). - // Not a fully-quantized implementation: floating-point log(). - const float float_log_sum_of_exps = - std::log(static_cast(sum_of_exps.raw()) / - (1 << (31 - kAccumulationIntegerBits))); - const int32 fixed_log_sum_of_exps = static_cast(TfLiteRound( - float_log_sum_of_exps * (1 << (31 - kScaledDiffIntegerBits)))); + const int32 fixed_log_sum_of_exps = + log_x_for_x_greater_than_or_equal_to_1( + sum_of_exps) + .raw(); // rescaled_diff_min is smallest representable in // Q(kScaledDiffIntegerBits).(31-kScaledDiffIntegerBits) plus the @@ -2698,9 +2946,9 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, fixed_log_sum_of_exps + std::numeric_limits::lowest(); const int adjusted_diff_min = std::max(diff_min - 1, // Note use of > below instead of >= above. - MultiplyByQuantizedMultiplierSmallerThanOne( + MultiplyByQuantizedMultiplierSmallerThanOneExp( rescaled_diff_min, reverse_scaling_divisor, - reverse_scaling_right_shift)); + kReverseShift * reverse_scaling_right_shift)); for (int c = 0; c < depth; ++c) { int32 input_diff = @@ -2956,9 +3204,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, } } -inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, +template +inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, const int32* output_size_data, - const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_size_dims, T* output_data, const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ -2990,15 +3239,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 x0 = static_cast(std::floor(input_x)); int32 x1 = std::min(x0 + 1, input_width - 1); for (int c = 0; c < depth; ++c) { - float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * - (1 - (input_y - y0)) * - (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x0, y1, b)] * - (input_y - y0) * (1 - (input_x - x0)) + - input_data[Offset(input_dims, c, x1, y0, b)] * - (1 - (input_y - y0)) * (input_x - x0) + - input_data[Offset(input_dims, c, x1, y1, b)] * - (input_y - y0) * (input_x - x0); + T interpolation = + static_cast(input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0)); output_data[Offset(output_dims, c, x, y, b)] = interpolation; } } @@ -3011,8 +3260,18 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims) { - ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, - output_data, output_dims, /*align_corners=*/false); + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); +} + +inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, uint8* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); } template @@ -3259,63 +3518,124 @@ inline void Exp(const T* input_data, const size_t num_elements, } } +// A generic reduce method that can be used for reduce_sum, reduce_mean, etc. +// It takes a reducer function as input and returns false when numeric overflow +// is detected. +// This method iterates through input data and reduce elements along the +// dimensions given in axis. +template +inline bool Reduce(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out reducer(Out current, const In in, bool* overflow), + Out* output_data) { + // Reset input iterator. + TFLITE_DCHECK(input_num_dims > 0); + for (int idx = 0; idx < input_num_dims; ++idx) { + input_iter[idx] = 0; + } + // Iterate through input_data. + do { + size_t input_offset = + ReducedOutputOffset(input_num_dims, input_dims, input_iter, 0, nullptr); + size_t output_offset = ReducedOutputOffset(input_num_dims, input_dims, + input_iter, num_axis, axis); + bool overflow = false; + output_data[output_offset] = reducer(output_data[output_offset], + input_data[input_offset], &overflow); + if (overflow) return false; + } while (NextIndex(input_num_dims, input_dims, input_iter)); + return true; +} + +inline bool ResolveAxis(const int num_dims, const int* axis, const int num_axis, + int* out_axis, int* out_num_axis) { + *out_num_axis = 0; // Just in case. + // o(n^2) is fine since out_num_axis should be really small, mostly <= 4 + for (int idx = 0; idx < num_axis; ++idx) { + // Handle negative index. + int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx]; + TFLITE_DCHECK(current >= 0 && current < num_dims); + bool is_dup = false; + for (int j = 0; j < *out_num_axis; ++j) { + if (out_axis[j] == current) { + is_dup = true; + break; + } + } + if (!is_dup) { + out_axis[*out_num_axis] = current; + *out_num_axis += 1; + } + } + return true; +} + +// This method expects that output_data has been initialized. +template +inline bool ReduceSumImpl(const In* input_data, const int* input_dims, + const int* output_dims, const int input_num_dims, + const int output_num_dims, const int* axis, + const int num_axis, int* input_iter, + Out* output_data) { + auto reducer = [](Out current, const In in, bool* overflow) -> Out { + const Out actual_in = static_cast(in); + return current + actual_in; + }; + return Reduce(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, axis, num_axis, input_iter, reducer, + output_data); +} + +// Computes the mean of elements across dimensions given in axis. +// It does so in two stages, first calculates the sum of elements along the axis +// then divides it by the number of element in axis. template inline bool Mean(const T* input_data, const int* input_dims, const int input_num_dims, T* output_data, const int* output_dims, const int output_num_dims, const int* axis, const int num_axis_dimensions, bool keep_dims, int* temp_index, int* resolved_axis, U* temp_sum) { - // resets output data. + // Reset output data. size_t num_outputs = 1; for (int idx = 0; idx < output_num_dims; ++idx) { - num_outputs *= static_cast(output_dims[idx]); + size_t current = static_cast(output_dims[idx]); + // Overflow prevention. + if (num_outputs > std::numeric_limits::max() / current) { + return false; + } + num_outputs *= current; } for (size_t idx = 0; idx < num_outputs; ++idx) { output_data[idx] = T(); temp_sum[idx] = U(); } - // resets temp index. - for (int idx = 0; idx < input_num_dims; ++idx) { - temp_index[idx] = 0; - } - // resolves axis. + + // Resolve axis. int num_resolved_axis = 0; - for (int idx = 0; idx < num_axis_dimensions; ++idx) { - int current = axis[idx]; - TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0); - if (current < 0) { - current += input_num_dims; - } - bool is_dup = false; - for (int j = 0; j < num_resolved_axis; ++j) { - if (resolved_axis[j] == current) { - is_dup = true; - break; - } - } - if (!is_dup) { - resolved_axis[num_resolved_axis++] = current; - } + if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis, + &num_resolved_axis)) { + return false; } - // iterates through input_data. - for (bool has_next = true; has_next; - has_next = NextIndex(input_num_dims, input_dims, temp_index)) { - size_t input_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr); - size_t output_offset = - ReducedOutputOffset(input_num_dims, input_dims, temp_index, - num_resolved_axis, resolved_axis); - temp_sum[output_offset] += static_cast(input_data[input_offset]); - } - // takes average by num of elements added to get mean. - size_t num_elements_in_axis = 1; + + if (!ReduceSumImpl(input_data, input_dims, output_dims, input_num_dims, + output_num_dims, resolved_axis, num_resolved_axis, + temp_index, temp_sum)) { + return false; + } + + // Calculate mean by dividing output_data by num of aggregated element. + U num_elements_in_axis = 1; for (int idx = 0; idx < num_resolved_axis; ++idx) { size_t current = static_cast(input_dims[resolved_axis[idx]]); + // Overflow prevention. if (current > (std::numeric_limits::max() / num_elements_in_axis)) { return false; } num_elements_in_axis *= current; } + if (num_elements_in_axis > 0) { for (size_t idx = 0; idx < num_outputs; ++idx) { output_data[idx] = @@ -3503,8 +3823,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, int pad_height, float* output_data, const Dims<4>& output_dims) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); - const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 3); - const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); const int input_height = ArraySize(input_dims, 2); const int input_width = ArraySize(input_dims, 1); const int filter_height = ArraySize(filter_dims, 2); @@ -3544,8 +3864,8 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, float input_value = input_data[Offset(input_dims, in_channel, in_x, in_y, batch)]; float filter_value = - filter_data[Offset(filter_dims, out_channel, filter_x, - filter_y, in_channel)]; + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] += input_value * filter_value; } @@ -3558,6 +3878,16 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims, } } +template +inline bool EqualFn(T lhs, T rhs) { + return lhs == rhs; +} + +template +inline bool NotEqualFn(T lhs, T rhs) { + return lhs != rhs; +} + template inline bool GreaterFn(T lhs, T rhs) { return lhs > rhs; @@ -3604,10 +3934,14 @@ inline void Comparison(int left_shift, const T* input1_data, const int32 input2_val = input2_offset + input2_data[i]; const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); - const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); - const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[i] = F(scaled_input1_val, scaled_input2_val); } } @@ -3656,11 +3990,13 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, const int32 shifted_input1_val = input1_val * (1 << left_shift); const int32 shifted_input2_val = input2_val * (1 << left_shift); const int32 scaled_input1_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input1_val, input1_multiplier, input1_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input1_val, input1_multiplier, + kReverseShift * input1_shift); const int32 scaled_input2_val = - MultiplyByQuantizedMultiplierSmallerThanOne( - shifted_input2_val, input2_multiplier, input2_shift); + MultiplyByQuantizedMultiplierSmallerThanOneExp( + shifted_input2_val, input2_multiplier, + kReverseShift * input2_shift); output_data[Offset(output_dims, c, x, y, b)] = F(scaled_input1_val, scaled_input2_val); } @@ -3715,6 +4051,8 @@ inline void BroadcastComparison(int left_shift, const T* input1_data, input2_offset, input2_multiplier, \ input2_shift, output_data, output_dims); \ } +TFLITE_COMPARISON_OP(Equal); +TFLITE_COMPARISON_OP(NotEqual); TFLITE_COMPARISON_OP(Greater); TFLITE_COMPARISON_OP(GreaterEqual); TFLITE_COMPARISON_OP(Less); @@ -3754,6 +4092,42 @@ inline void RankOneSelect(const D* input_condition_data, } } +// For easy implementation, the indices is always a vector of size-4 vectors. +template +inline void SparseToDense(const std::vector>& indices, + const T* values, T default_value, T* output_data, + const Dims<4>& output_dims, bool value_is_scalar) { + const int value_count = indices.size(); + + // First fill the output_data with default value. + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; ++i) { + output_data[i] = default_value; + } + + // Special handle for value is scalar case to avoid checking the boolean + // condition within the loop every time. + if (value_is_scalar) { + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = *values; // just use the first value. + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } + return; + } + + // Go through the values and indices to fill the sparse values. + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = values[i]; + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d8765f11b2941ef5871c7db8e3582e506713aa6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc @@ -0,0 +1,136 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace { +template +void TestOneResizeBilinear(int batch, int depth, int input_width, + int input_height, int output_width, + int output_height, float error_threshold) { + Dims<4> input_dims_inference = + MakeDimsForInference(depth, input_width, input_height, batch); + Dims<4> output_dims_inference = + MakeDimsForInference(depth, output_width, output_height, batch); + + const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference); + const int output_buffer_size = + RequiredBufferSizeForDims(output_dims_inference); + + std::vector input_data(input_buffer_size, 0); + std::vector reference_output_data(output_buffer_size, 0); + // Initialize the output data with something other than zero, so we can catch + // issue with kernels failing to initialize the output. + std::vector output_data(output_buffer_size, 3); + + const T min_amplitude = static_cast(0); + const T max_amplitude = static_cast(255); + FillRandom(&input_data, min_amplitude, max_amplitude); + + Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1); + std::vector output_size_data = {output_height, output_width}; + + reference_ops::ResizeBilinear( + input_data.data(), input_dims_inference, output_size_data.data(), + output_size_dims, reference_output_data.data(), output_dims_inference); + optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference, + output_size_data.data(), output_size_dims, + output_data.data(), output_dims_inference); + + double sum_diff = 0; + float max_abs_val = 0; + for (int i = 0; i < output_buffer_size; i++) { + sum_diff += std::abs(static_cast(output_data[i]) - + static_cast(reference_output_data[i])); + max_abs_val = std::max( + max_abs_val, std::abs(static_cast(reference_output_data[i]))); + } + + if (sum_diff != 0.f) { + const float mean_diff = static_cast(sum_diff / output_buffer_size); + const float relative_error = std::abs(mean_diff) / max_abs_val; + ASSERT_LT(relative_error, error_threshold); + } +} + +TEST(ResizeBilinear, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 0.025); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} + +TEST(ResizeBilinear, TestResizeBilinear) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} + +TEST(ResizeBilinear2x2, TestResizeBilinear) { + const int kTestsToRun = 100 * 1000; + for (int i = 0; i < kTestsToRun; i++) { + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50); + const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200); + const int output_width = input_width * 2; + const int output_height = input_height * 2; + + TestOneResizeBilinear(batch, depth, input_width, input_height, + output_width, output_height, 1e-5); + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d781a7b642036f3c5ddaa366f257fe26511c83c3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc @@ -0,0 +1,227 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +namespace tflite { +namespace { + +void RunSoftmaxFloatReference(const uint8* input_data, + const Dims<4>& dims_common, int32 input_offset, + const double input_scale, int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector reference_dequant_data(ref_buffer_size); + std::vector reference_output_float_data(ref_buffer_size); + + // Reference data generated via Dequant of input into float, and then applying + // float Softmax. + reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, + reference_dequant_data.data(), dims_common); + optimized_ops::Softmax(reference_dequant_data.data(), dims_common, beta, + reference_output_float_data.data(), dims_common); + // Work with quantized scaling for Softmax, under which 256 represents 1, but + // we limit this to 255. + for (int i = 0; i < ref_buffer_size; i++) { + reference_output_data[i] = std::min( + 255, + static_cast(std::round(256.0f * reference_output_float_data[i]))); + } +} + +void CheckOutputData(const uint8* test_output, const uint8* reference_output, + const Dims<4>& dims_common, const string& check_label, + bool be_exacting) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + // While calculating some metrics in floating point, we work with quantized + // scaling. + std::vector diff(buffer_size); + int64_t sum_diff = 0; + int64_t sum_abs_diff = 0; + for (int i = 0; i < buffer_size; i++) { + diff[i] = static_cast(test_output[i]) - reference_output[i]; + sum_diff += diff[i]; + sum_abs_diff += std::abs(diff[i]); + } + // These stats help understand test failures. + std::sort(std::begin(diff), std::end(diff)); + const int min_diff = diff.front(); + const int max_diff = diff.back(); + const int median_diff = diff[diff.size() / 2]; + const float mean_diff = static_cast(sum_diff) / buffer_size; + const float mean_abs_diff = static_cast(sum_abs_diff) / buffer_size; + // We either check for bit exactness (against the reference quantized version) + // or for general accuracy, allowing off-by-one (against the float reference). + if (be_exacting) { + ASSERT_TRUE(std::abs(min_diff) == 0 && std::abs(max_diff) == 0); + } else { + // For small numbers of samples, the estimates of the means vary more. + // Rather than widen the tolerances, we skip the smaller tests. + ASSERT_TRUE(((std::abs(mean_diff) < 2e-2f && mean_abs_diff < 3e-2f) || + buffer_size < 10000) && + std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 && + std::abs(max_diff) <= 1); + } +} + +// Runs the Softmax and compares against the float reference implementation and +// the quantized reference implementation. +void RunOneSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, + int32 input_offset, const double input_scale, int stride, + float beta) { + const int buffer_size = RequiredBufferSizeForDims(dims_common); + std::vector optimized_softmax_output(buffer_size); + std::vector reference_float_softmax_output(buffer_size); + std::vector reference_quant_softmax_output(buffer_size); + + RunSoftmaxFloatReference(input_data, dims_common, input_offset, input_scale, + stride, beta, reference_float_softmax_output.data()); + + int32 input_beta_multiplier; + int input_beta_left_shift; + static const int kScaledDiffIntegerBits = 5; + tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits, + &input_beta_multiplier, + &input_beta_left_shift); + // diff_min has a negative value, and is used to limit the maximum magnitude + // of the diffs, which are <= 0. + const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, + input_beta_left_shift); + + optimized_ops::Softmax(input_data, dims_common, input_beta_multiplier, + input_beta_left_shift, diff_min, + optimized_softmax_output.data(), dims_common); + reference_ops::Softmax(input_data, dims_common, input_beta_multiplier, + input_beta_left_shift, diff_min, + reference_quant_softmax_output.data(), dims_common); + + CheckOutputData(optimized_softmax_output.data(), + reference_float_softmax_output.data(), dims_common, + "Optimized vs float reference", false); + CheckOutputData(optimized_softmax_output.data(), + reference_quant_softmax_output.data(), dims_common, + "Optimized vs quant reference", true); + CheckOutputData(reference_quant_softmax_output.data(), + reference_float_softmax_output.data(), dims_common, + "Quant reference vs float reference", false); +} + +// This function picks some random Softmax params, which are checked for +// desirability. If not acceptable, it returns false. If they're OK, +// it runs the Softmax test and returns true. This allows the caller +// to loop until a test has been run. +// +// Currently we do not reject for any reason. +bool TryOneUniformSoftmax() { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // Softmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandom(&input_data); + RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + stride, beta); + return true; +} + +// See TryOneUniformSoftmax() for a general description. +// +// Tests with "skyscraper" input patterns are included for two reasons. (a) +// Bimodal distributions are potentially challenging and perhaps more +// realistic than simple uniform random inputs. (b) Some implementations of +// Softmax may adapt as they traverse the depth, and so we test handling of +// cases where relatively small values are encountered at the beginning and end. +bool TryOneSkyscraperSoftmax(bool small_depth) { + // We pick mostly positive values, on the whole emphasizing smaller values and + // therefore faster tests. We test a wider range of depths. In the case of + // Softmax, the width and height really just create test repetitions. + const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20); + const int input_depth = small_depth + ? ExponentialRandomPositiveInt(0.75f, 40, 500) + : ExponentialRandomPositiveInt(0.75f, 175, 500); + const int input_width = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); + const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); + const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); + const int32 input_offset = UniformRandomInt(-256, 0); + const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); + // Extra parameters for skyscraper input patterns. + const double middle_proportion = + ExponentialRandomPositiveFloat(0.65f, 0.1, 1.0); + const int middle_min = UniformRandomInt(0, 255); + const int sides_max = UniformRandomInt(0, middle_min); + + Dims<4> dims_common = + MakeDimsForInference(input_depth, input_width, input_height, batch); + const int buffer_size = RequiredBufferSizeForDims(dims_common); + + std::vector input_data(buffer_size); + FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, + sides_max); + RunOneSoftmaxTest(input_data.data(), dims_common, input_offset, input_scale, + stride, beta); + return true; +} + +TEST(TestQuantizedSoftmax, UniformSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneUniformSoftmax()) { + } + } +} + +TEST(TestQuantizedSoftmax, SkyscraperSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperSoftmax(false)) { + } + } +} + +TEST(TestQuantizedSoftmax, SmallSkyscraperSoftmaxTests) { + const int kTestsToRun = 1000; + for (int i = 0; i < kTestsToRun; i++) { + while (!TryOneSkyscraperSoftmax(true)) { + } + } +} +} // namespace +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index e1c9ccd84b09fdca0241192ef4dff62ded096433..5160e22307ae0894fabd0e9c4f7b9cd38b00840e 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -23,6 +23,9 @@ namespace tensor_utils { // Limit a float input f between +abs_limit and -abs_limit. float Clip(float f, float abs_limit); +// Checks if all entries of vector are zero. +bool IsZeroVector(const float* vector, int v_size); + // Quantizes a buffer of floating point values using a symmetric quantization // (i.e. linear quantization without an offset) to 8-bit signed integers. // It also outputs the range (min, max) of the floating point buffer, and the diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc index 3d8a2eada0c30132b5646327da5cd9fd80ccd39f..14ee528394b6872d9e79969db0e431658277f56b 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -32,6 +32,25 @@ TEST(uKernels, ClipTest) { {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); } +TEST(uKernels, IsZeroTest) { + constexpr int kVectorSize = 21; + static float zeros[kVectorSize] = {0.0}; + EXPECT_TRUE(IsZeroVector(zeros, kVectorSize)); + + static float nonzeros[kVectorSize] = { + 1e-6, 1e-7, 1e-8, 1e-9, 1e-10, 1e-11, 1e-12, + 1e-13, 1e-14, 1e-15, 1e-16, 1e-17, 1e-18, 1e-19, + 1e-20, 1e-21, 1e-22, 1e-23, 1e-24, 1e-25, 1e-26}; + EXPECT_FALSE(IsZeroVector(nonzeros, kVectorSize)); +} + +TEST(uKernels, GeneratedIsZeroTest) { + constexpr int kVectorSize = 39; + std::vector input(kVectorSize); + ZeroVector(input.data(), kVectorSize); + EXPECT_TRUE(IsZeroVector(input.data(), kVectorSize)); +} + TEST(uKernels, SymmetricQuantizeFloatsTest) { constexpr int kVectorSize = 9; static float input[kVectorSize] = {-640, -635.0, -630, 10.0, 2.0, diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.cc b/tensorflow/contrib/lite/kernels/internal/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b1fd9b344d99103ee2a1b5b95fd697ccf4a64d0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/test_util.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/test_util.h" + +#include +#include + +namespace tflite { + +Dims<4> MakeDimsForInference(int depth, int width, int height, int batch) { + Dims<4> result; + int cum_prod = 1; + + result.sizes[0] = depth; + result.strides[0] = cum_prod; + cum_prod *= result.sizes[0]; + + result.sizes[1] = width; + result.strides[1] = cum_prod; + cum_prod *= result.sizes[1]; + + result.sizes[2] = height; + result.strides[2] = cum_prod; + cum_prod *= result.sizes[2]; + + result.sizes[3] = batch; + result.strides[3] = cum_prod; + + return result; +} + +// this is a copied from an internal function in propagate_fixed_sizes.cc +bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, + int filter_height, int stride, PaddingType padding_type, + Dims<4>* output_dims, int* pad_width, int* pad_height) { + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int batch = ArraySize(input_dims, 3); + + int output_height = 0; + int output_width = 0; + if (padding_type == PaddingType::kValid) { + output_height = (input_height + stride - filter_height) / stride; + output_width = (input_width + stride - filter_width) / stride; + } else if (padding_type == PaddingType::kSame) { + output_height = (input_height + stride - 1) / stride; + output_width = (input_width + stride - 1) / stride; + } else { + return false; + } + + if (output_width <= 0 || output_height <= 0) { + return false; + } + + *pad_height = + ((output_height - 1) * stride + filter_height - input_height) / 2; + *pad_width = ((output_width - 1) * stride + filter_width - input_width) / 2; + *output_dims = + MakeDimsForInference(output_depth, output_width, output_height, batch); + return true; +} + +std::mt19937& RandomEngine() { + static std::mt19937 engine; + return engine; +} + +int UniformRandomInt(int min, int max) { + std::uniform_int_distribution dist(min, max); + return dist(RandomEngine()); +} + +float UniformRandomFloat(float min, float max) { + std::uniform_real_distribution dist(min, max); + return dist(RandomEngine()); +} + +int ExponentialRandomPositiveInt(float percentile, int percentile_val, + int max_val) { + const float lambda = + -std::log(1.f - percentile) / static_cast(percentile_val); + std::exponential_distribution dist(lambda); + float val; + do { + val = dist(RandomEngine()); + } while (!val || !std::isfinite(val) || val > max_val); + return static_cast(std::ceil(val)); +} + +float ExponentialRandomPositiveFloat(float percentile, float percentile_val, + float max_val) { + const float lambda = + -std::log(1.f - percentile) / static_cast(percentile_val); + std::exponential_distribution dist(lambda); + float val; + do { + val = dist(RandomEngine()); + } while (!std::isfinite(val) || val > max_val); + return val; +} + +void FillRandom(std::vector* vec, float min, float max) { + std::uniform_real_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(std::begin(*vec), std::end(*vec), gen); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/test_util.h b/tensorflow/contrib/lite/kernels/internal/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..26078cef49a7868c64ff0095898eebe5e4de8751 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/test_util.h @@ -0,0 +1,104 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +// Creates a Dims struct from a set of dimensions. +Dims<4> MakeDimsForInference(int depth, int width, int height, int batch); + +// Computes output and padding dimensions. +bool ComputeConvSizes(Dims<4> input_dims, int output_depth, int filter_width, + int filter_height, int stride, PaddingType padding_type, + Dims<4>* output_dims, int* pad_width, int* pad_height); + +// Returns a mt19937 random engine. +std::mt19937& RandomEngine(); + +// Returns a random integer uniformly distributed between |min| and |max|. +int UniformRandomInt(int min, int max); + +// Returns a random float uniformly distributed between |min| and |max|. +float UniformRandomFloat(float min, float max); + +// Returns a random element in |v|. +template +const T& RandomElement(const std::vector& v) { + return v[UniformRandomInt(0, v.size() - 1)]; +} + +// Returns a random exponentially distributed integer. +int ExponentialRandomPositiveInt(float percentile, int percentile_val, + int max_val); + +// Returns a random exponentially distributed float. +float ExponentialRandomPositiveFloat(float percentile, float percentile_val, + float max_val); + +// Fills a vector with random floats between |min| and |max|. +void FillRandom(std::vector* vec, float min, float max); + +// Fills a vector with random numbers between |min| and |max|. +template +void FillRandom(std::vector* vec, T min, T max) { + std::uniform_int_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(std::begin(*vec), std::end(*vec), gen); +} + +// Fills a vector with random numbers. +template +void FillRandom(std::vector* vec) { + FillRandom(vec, std::numeric_limits::min(), std::numeric_limits::max()); +} + +template +void FillRandom(typename std::vector::iterator begin_it, + typename std::vector::iterator end_it, T min, T max) { + std::uniform_int_distribution dist(min, max); + auto gen = std::bind(dist, RandomEngine()); + std::generate(begin_it, end_it, gen); +} + +// Fill with a "skyscraper" pattern, in which there is a central section (across +// the depth) with higher values than the surround. +template +void FillRandomSkyscraper(std::vector* vec, int depth, + double middle_proportion, uint8 middle_min, + uint8 sides_max) { + for (auto base_it = std::begin(*vec); base_it != std::end(*vec); + base_it += depth) { + auto left_it = base_it + std::ceil(0.5 * depth * (1.0 - middle_proportion)); + auto right_it = + base_it + std::ceil(0.5 * depth * (1.0 + middle_proportion)); + FillRandom(base_it, left_it, std::numeric_limits::min(), sides_max); + FillRandom(left_it, right_it, middle_min, std::numeric_limits::max()); + FillRandom(right_it, base_it + depth, std::numeric_limits::min(), + sides_max); + } +} + +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 43c68832785ac87f51c298370b50dc722167dc7f..3ecef15271c6eb9dd9d6dd370377fddda2723fcf 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,11 +15,15 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ #define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#include +#include + #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" namespace tflite { enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; +enum class PaddingType { kNone, kSame, kValid }; // Quantization parameters, determining the mapping of quantized values // to real values (i.e. determining how quantized values are mathematically @@ -43,6 +47,121 @@ struct Dims { int strides[N]; }; +class RuntimeShape { + public: + // Shapes with dimensions up to 4 are stored directly in the structure, while + // larger shapes are separately allocated. + static constexpr int kMaxSmallSize = 4; + + RuntimeShape() : size_(0) {} + + explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) { + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + RuntimeShape(int dimensions_count, const int32* dims_data) : size_(0) { + ReplaceWith(dimensions_count, dims_data); + } + + ~RuntimeShape() { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + } + + inline int32 DimensionsCount() const { return size_; } + inline int32 Dims(int i) const { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + return size_ > kMaxSmallSize ? dims_pointer_[i] : dims_[i]; + } + inline void SetDim(int i, int32 val) { + TFLITE_DCHECK_GE(i, 0); + TFLITE_DCHECK_LT(i, size_); + if (size_ > kMaxSmallSize) { + dims_pointer_[i] = val; + } else { + dims_[i] = val; + } + } + inline int32* DimsData() { + return size_ > kMaxSmallSize ? dims_pointer_ : dims_; + } + inline const int32* DimsData() const { + return size_ > kMaxSmallSize ? dims_pointer_ : dims_; + } + + inline void Resize(int dimensions_count) { + if (size_ > kMaxSmallSize) { + delete[] dims_pointer_; + } + size_ = dimensions_count; + if (dimensions_count > kMaxSmallSize) { + dims_pointer_ = new int32[dimensions_count]; + } + } + + inline void ReplaceWith(int dimensions_count, const int32* dims_data) { + Resize(dimensions_count); + int32* dst_dims = DimsData(); + std::memcpy(dst_dims, dims_data, dimensions_count * sizeof(int32)); + } + + template + inline void BuildFrom(const T& src_iterable) { + const int dimensions_count = + std::distance(src_iterable.begin(), src_iterable.end()); + Resize(dimensions_count); + int32* data = DimsData(); + for (auto it : src_iterable) { + *data = it; + ++data; + } + } + + inline void BuildFrom(const std::initializer_list init_list) { + BuildFrom>(init_list); + } + + // Returns the total count of elements, that is the size when flattened into a + // vector. + inline int FlatSize() const { + int buffer_size = 1; + const int* dims_data = DimsData(); + for (int i = 0; i < size_; i++) { + const int dim = dims_data[i]; + TFLITE_DCHECK_GE(dim, 1); + buffer_size *= dim; + } + return buffer_size; + } + + private: + int32 size_; + union { + int32 dims_[kMaxSmallSize]; + int32* dims_pointer_; + }; +}; + +// Converts inference-style shape to legacy tflite::Dims<4>. +inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { + tflite::Dims<4> result; + const int dimensions_count = array_shape.DimensionsCount(); + TFLITE_CHECK_LE(dimensions_count, 4); + int cum_prod = 1; + for (int i = 0; i < 4; i++) { + const int new_dim = + (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1; + result.sizes[i] = new_dim; + result.strides[i] = cum_prod; + cum_prod *= new_dim; + } + return result; +} + // Gets next index to iterate through a multidimensional array. inline bool NextIndex(const int num_dims, const int* dims, int* current) { TFLITE_DCHECK_GT(num_dims, 0); @@ -259,6 +378,14 @@ bool IsPackedWithoutStrides(const Dims& dims) { return true; } +template +void ComputeStrides(Dims* dims) { + dims->strides[0] = 1; + for (int d = 1; d < N; d++) { + dims->strides[d] = dims->strides[d - 1] * dims->sizes[d - 1]; + } +} + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index 239b533a17efaa25d632c140c17e9f35d01a80ef..184028427fb193aa99cf155961c16eda1298e326 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -37,7 +37,6 @@ TfLiteStatus GetQuantizedConvolutionMultipler(TfLiteContext* context, TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <= 1e-6 * std::min(input_product_scale, bias_scale)); TF_LITE_ENSURE(context, input_product_scale >= 0); - TF_LITE_ENSURE(context, input_product_scale < output_scale); *multiplier = input_product_scale / output_scale; diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index 11cc666bad61b5753d933b98510ee8d4a093644e..070ed60040997f18f7e8053acc9532adc2377400 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -67,7 +67,7 @@ class L2NormOpModel : public SingleOpModel { int output_; }; -TEST(L2NormOpTest, SimpleTest) { +TEST(L2NormOpTest, SimpleFloatTest) { L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE); m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); @@ -76,7 +76,7 @@ TEST(L2NormOpTest, SimpleTest) { ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); } -TEST(L2NormOpTest, MultipleBatchesTest) { +TEST(L2NormOpTest, MultipleBatchFloatTest) { L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE); m.SetInput({ @@ -105,6 +105,32 @@ TEST(L2NormOpTest, SimpleUint8Test) { ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1))); } +TEST(L2NormOpTest, MultipleBatchUint8Test) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE); + + m.QuantizeAndPopulate(m.input(), + { + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({ + 58, 166, 173, 205, 83, 134, // batch 1 + 58, 166, 173, 205, 83, 134, // batch 2 + 58, 166, 173, 205, 83, 134, // batch 3 + })); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + }, + 0.1))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc index 990b3da0554ebcb13f995fa281ed04f8c7c6d7ea..eb26a02455ce2afccaa081a72d93a9ceeca746cc 100644 --- a/tensorflow/contrib/lite/kernels/lstm.cc +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" #include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" #include "tensorflow/contrib/lite/kernels/kernel_util.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" @@ -34,6 +36,17 @@ namespace ops { namespace builtin { namespace lstm { +struct OpData { + // Which kernel type to use. Full kernel (18-inputs) or basic kernel + // (5-inputs). + TfLiteLSTMKernelType kernel_type; + // Only used by full kernel. + int scratch_tensor_index; +}; + +// For full inputs kernel (18-inputs). +namespace full { + // Input Tensors of size {n_batch, n_input} constexpr int kInputTensor = 0; @@ -71,20 +84,18 @@ constexpr int kCellStateTensor = 1; constexpr int kOutputTensor = 2; void* Init(TfLiteContext* context, const char* buffer, size_t length) { - auto* scratch_tensor_index = new int; - context->AddTensors(context, 1, scratch_tensor_index); - return scratch_tensor_index; -} - -void Free(TfLiteContext* context, void* buffer) { - delete reinterpret_cast(buffer); + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMFullKernel; + context->AddTensors(context, /*tensors_to_add=*/7, + &op_data->scratch_tensor_index); + return op_data; } // Check that input tensor dimensions matches with each other. TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, int n_cell) { - auto* params = reinterpret_cast(node->builtin_data); + const auto* params = reinterpret_cast(node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -94,7 +105,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - if (input_to_input_weights) { + if (input_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); @@ -114,7 +125,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - if (recurrent_to_input_weights) { + if (recurrent_to_input_weights != nullptr) { TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], n_cell); @@ -204,7 +215,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_weights = GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - if (projection_weights) { + if (projection_weights != nullptr) { TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); @@ -212,7 +223,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, const TfLiteTensor* projection_bias = GetOptionalInputTensor(context, node, kProjectionBiasTensor); - if (projection_bias) { + if (projection_bias != nullptr) { TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); } @@ -233,7 +244,7 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // Allocate a temporary scratch tensor. Also check that the sizes of the input // tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - int* scratch_tensor_index = reinterpret_cast(node->user_data); + OpData* op_data = reinterpret_cast(node->user_data); // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); @@ -242,6 +253,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Inferring batch size, number of outputs and number of cells from the // input tensors. const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE(context, input->dims->size > 1); const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; @@ -286,86 +298,148 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, cell_state, cell_size)); - // Create a scratch buffer tensor. + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + // The weights are of consistent type, so it suffices to check one. + // TODO(mirkov): create a utility/macro for this check, so all Ops can use it. + const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 && + input->type == kTfLiteFloat32); + TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(1); - node->temporaries->data[0] = *scratch_tensor_index; + if (is_hybrid_op) { + node->temporaries = TfLiteIntArrayCreate(7); + } else { + node->temporaries = TfLiteIntArrayCreate(1); + } + node->temporaries->data[0] = op_data->scratch_tensor_index; + + // Create a scratch buffer tensor. TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); scratch_buffer->type = input->type; scratch_buffer->allocation_type = kTfLiteArenaRw; - // Mark state tensors as persistent tensors. - output_state->allocation_type = kTfLiteArenaRwPersistent; - cell_state->allocation_type = kTfLiteArenaRwPersistent; - const TfLiteTensor* input_to_input_weights = GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); const bool use_cifg = (input_to_input_weights == nullptr); + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; if (use_cifg) { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 3; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); } else { - TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); - scratch_buffer_size->data[0] = n_batch; // Reserving space for Input, Cell, Forget, Output gates scratch_buffer_size->data[1] = n_cell * 4; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, - scratch_buffer_size)); + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + + if (is_hybrid_op) { + // Allocate temporary tensors to store quantized values of input, + // output_state and cell_state tensors. + node->temporaries->data[1] = op_data->scratch_tensor_index + 1; + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + input_quantized->type = kTfLiteUInt8; + input_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { + TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, + input_quantized_size)); + } + node->temporaries->data[2] = op_data->scratch_tensor_index + 2; + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + output_state_quantized->type = kTfLiteUInt8; + output_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(output_state_quantized->dims, + output_state->dims)) { + TfLiteIntArray* output_state_quantized_size = + TfLiteIntArrayCopy(output_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output_state_quantized, + output_state_quantized_size)); + } + node->temporaries->data[3] = op_data->scratch_tensor_index + 3; + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + cell_state_quantized->type = kTfLiteUInt8; + cell_state_quantized->allocation_type = kTfLiteArenaRw; + if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { + TfLiteIntArray* cell_state_quantized_size = + TfLiteIntArrayCopy(cell_state->dims); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state_quantized, + cell_state_quantized_size)); + } + + // Allocate temporary tensors to store scaling factors and product scaling + // factors. The latter is a convenience storage which allows to quantize + // a vector once (which produces the scaling factors) and multiply it with + // different matrices (which requires multiplying the scaling factors with + // the scaling factor of the matrix). + node->temporaries->data[4] = op_data->scratch_tensor_index + 4; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } + node->temporaries->data[5] = op_data->scratch_tensor_index + 5; + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + prod_scaling_factors->type = kTfLiteFloat32; + prod_scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); + prod_scaling_factors_size->data[0] = n_batch; + if (!TfLiteIntArrayEqual(prod_scaling_factors->dims, + prod_scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, prod_scaling_factors, + prod_scaling_factors_size)); + } + + // Allocate a temporary tensor to store the recovered cell weights. Since + // this is used for diagonal matrices, only need to store n_cell values. + node->temporaries->data[6] = op_data->scratch_tensor_index + 6; + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + recovered_cell_weights->type = kTfLiteFloat32; + recovered_cell_weights->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); + recovered_cell_weights_size->data[0] = n_cell; + if (!TfLiteIntArrayEqual(recovered_cell_weights->dims, + recovered_cell_weights_size)) { + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, recovered_cell_weights, + recovered_cell_weights_size)); + } } return kTfLiteOk; } // The LSTM Op engine. -TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - auto* params = reinterpret_cast(node->builtin_data); - const TfLiteTensor* input = GetInput(context, node, kInputTensor); - - const TfLiteTensor* input_to_input_weights = - GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); - const TfLiteTensor* input_to_forget_weights = - GetInput(context, node, kInputToForgetWeightsTensor); - const TfLiteTensor* input_to_cell_weights = - GetInput(context, node, kInputToCellWeightsTensor); - const TfLiteTensor* input_to_output_weights = - GetInput(context, node, kInputToOutputWeightsTensor); - - const TfLiteTensor* recurrent_to_input_weights = - GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); - const TfLiteTensor* recurrent_to_forget_weights = - GetInput(context, node, kRecurrentToForgetWeightsTensor); - const TfLiteTensor* recurrent_to_cell_weights = - GetInput(context, node, kRecurrentToCellWeightsTensor); - const TfLiteTensor* recurrent_to_output_weights = - GetInput(context, node, kRecurrentToOutputWeightsTensor); - - const TfLiteTensor* cell_to_input_weights = - GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); - const TfLiteTensor* cell_to_forget_weights = - GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); - const TfLiteTensor* cell_to_output_weights = - GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); - - const TfLiteTensor* input_gate_bias = - GetOptionalInputTensor(context, node, kInputGateBiasTensor); - const TfLiteTensor* forget_gate_bias = - GetInput(context, node, kForgetGateBiasTensor); - const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); - const TfLiteTensor* output_gate_bias = - GetInput(context, node, kOutputGateBiasTensor); - - const TfLiteTensor* projection_weights = - GetOptionalInputTensor(context, node, kProjectionWeightsTensor); - const TfLiteTensor* projection_bias = - GetOptionalInputTensor(context, node, kProjectionBiasTensor); - - TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); - TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); - TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - +TfLiteStatus EvalFloat( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { const int n_batch = input->dims->data[0]; const int n_input = input->dims->data[1]; // n_cell and n_output will be the same size when there is no projection. @@ -377,9 +451,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const bool use_cifg = (input_to_input_weights == nullptr); const bool use_peephole = (cell_to_output_weights != nullptr); - // Index the scratch buffers pointers to the global scratch buffer. - TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); - float* input_gate_scratch = nullptr; float* cell_scratch = nullptr; float* forget_gate_scratch = nullptr; @@ -447,6 +518,421 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteOk; } +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights, + const TfLiteTensor* input_to_forget_weights, + const TfLiteTensor* input_to_cell_weights, + const TfLiteTensor* input_to_output_weights, + const TfLiteTensor* recurrent_to_input_weights, + const TfLiteTensor* recurrent_to_forget_weights, + const TfLiteTensor* recurrent_to_cell_weights, + const TfLiteTensor* recurrent_to_output_weights, + const TfLiteTensor* cell_to_input_weights, + const TfLiteTensor* cell_to_forget_weights, + const TfLiteTensor* cell_to_output_weights, + const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, + const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, + const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, + const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer, + TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, + TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, + TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized, + TfLiteTensor* output_state, TfLiteTensor* cell_state, + TfLiteTensor* output) { + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existence of only one to get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + int8_t* input_to_input_weights_ptr = nullptr; + float input_to_input_weights_scale = 1.0f; + int8_t* recurrent_to_input_weights_ptr = nullptr; + float recurrent_to_input_weights_scale = 1.0f; + float* input_gate_bias_ptr = nullptr; + if (!use_cifg) { + input_to_input_weights_ptr = + reinterpret_cast(input_to_input_weights->data.uint8); + recurrent_to_input_weights_ptr = + reinterpret_cast(recurrent_to_input_weights->data.uint8); + input_gate_bias_ptr = input_gate_bias->data.f; + input_to_input_weights_scale = input_to_input_weights->params.scale; + recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale; + } + + int8_t* cell_to_input_weights_ptr = nullptr; + int8_t* cell_to_forget_weights_ptr = nullptr; + int8_t* cell_to_output_weights_ptr = nullptr; + float cell_to_input_weights_scale = 1.0f; + float cell_to_forget_weights_scale = 1.0f; + float cell_to_output_weights_scale = 1.0f; + if (use_peephole) { + if (!use_cifg) { + cell_to_input_weights_ptr = + reinterpret_cast(cell_to_input_weights->data.uint8); + cell_to_input_weights_scale = cell_to_input_weights->params.scale; + } + cell_to_forget_weights_ptr = + reinterpret_cast(cell_to_forget_weights->data.uint8); + cell_to_output_weights_ptr = + reinterpret_cast(cell_to_output_weights->data.uint8); + cell_to_forget_weights_scale = cell_to_forget_weights->params.scale; + cell_to_output_weights_scale = cell_to_output_weights->params.scale; + } + + const int8_t* projection_weights_ptr = + (projection_weights == nullptr) + ? nullptr + : reinterpret_cast(projection_weights->data.uint8); + const float projection_weights_scale = + (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale; + const float* projection_bias_ptr = + (projection_bias == nullptr) ? nullptr : projection_bias->data.f; + + // Required tensors, pointers are non-null. + const float* input_ptr_batch = input->data.f; + const int8_t* input_to_forget_weights_ptr = + reinterpret_cast(input_to_forget_weights->data.uint8); + const float input_to_forget_weights_scale = + input_to_forget_weights->params.scale; + const int8_t* input_to_cell_weights_ptr = + reinterpret_cast(input_to_cell_weights->data.uint8); + const float input_to_cell_weights_scale = input_to_cell_weights->params.scale; + const int8_t* input_to_output_weights_ptr = + reinterpret_cast(input_to_output_weights->data.uint8); + const float input_to_output_weights_scale = + input_to_output_weights->params.scale; + const int8_t* recurrent_to_forget_weights_ptr = + reinterpret_cast(recurrent_to_forget_weights->data.uint8); + const float recurrent_to_forget_weights_scale = + recurrent_to_forget_weights->params.scale; + const int8_t* recurrent_to_cell_weights_ptr = + reinterpret_cast(recurrent_to_cell_weights->data.uint8); + const float recurrent_to_cell_weights_scale = + recurrent_to_cell_weights->params.scale; + const int8_t* recurrent_to_output_weights_ptr = + reinterpret_cast(recurrent_to_output_weights->data.uint8); + const float recurrent_to_output_weights_scale = + recurrent_to_output_weights->params.scale; + const float* forget_gate_bias_ptr = forget_gate_bias->data.f; + const float* cell_bias_ptr = cell_bias->data.f; + const float* output_gate_bias_ptr = output_gate_bias->data.f; + + float* output_state_ptr = output_state->data.f; + float* cell_state_ptr = cell_state->data.f; + float* output_ptr_batch = output->data.f; + + // Temporary storage for quantized values and scaling factors. + int8_t* quantized_input_ptr = + reinterpret_cast(input_quantized->data.uint8); + int8_t* quantized_output_state_ptr = + reinterpret_cast(output_state_quantized->data.uint8); + int8_t* quantized_cell_state_ptr = + reinterpret_cast(cell_state_quantized->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; + float* prod_scaling_factors_ptr = prod_scaling_factors->data.f; + float* recovered_cell_weights_ptr = recovered_cell_weights->data.f; + + kernel_utils::LstmStep( + input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale, + input_to_forget_weights_ptr, input_to_forget_weights_scale, + input_to_cell_weights_ptr, input_to_cell_weights_scale, + input_to_output_weights_ptr, input_to_output_weights_scale, + recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale, + recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale, + recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale, + recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale, + cell_to_input_weights_ptr, cell_to_input_weights_scale, + cell_to_forget_weights_ptr, cell_to_forget_weights_scale, + cell_to_output_weights_ptr, cell_to_output_weights_scale, + input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr, + output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale, + projection_bias_ptr, params, n_batch, n_cell, n_input, n_output, + input_gate_scratch, forget_gate_scratch, cell_scratch, + output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr, + recovered_cell_weights_ptr, quantized_input_ptr, + quantized_output_state_ptr, quantized_cell_state_ptr, output_state_ptr, + cell_state_ptr, output_ptr_batch); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast(node->builtin_data); + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + const TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + const TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + const TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + const TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + const TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + const TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + const TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + const TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + const TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + const TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + const TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + const TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + const TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + const TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + const TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(mirkov): add a check that weights are all uint8s or all floats. + switch (input_to_output_weights->type) { + case kTfLiteFloat32: { + return EvalFloat(input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, + cell_to_output_weights, input_gate_bias, + forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, + scratch_buffer, output_state, cell_state, output); + } + case kTfLiteUInt8: { + TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1); + TfLiteTensor* output_state_quantized = + GetTemporary(context, node, /*index=*/2); + TfLiteTensor* cell_state_quantized = + GetTemporary(context, node, /*index=*/3); + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4); + TfLiteTensor* prod_scaling_factors = + GetTemporary(context, node, /*index=*/5); + TfLiteTensor* recovered_cell_weights = + GetTemporary(context, node, /*index=*/6); + return EvalHybrid( + input, input_to_input_weights, input_to_forget_weights, + input_to_cell_weights, input_to_output_weights, + recurrent_to_input_weights, recurrent_to_forget_weights, + recurrent_to_cell_weights, recurrent_to_output_weights, + cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, + input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias, + projection_weights, projection_bias, params, scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, output_state_quantized, cell_state_quantized, + output_state, cell_state, output); + } + default: + context->ReportError(context, "Type %d is not currently supported.", + input_to_output_weights->type); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace full + +// For basic kernel (5-inputs). +namespace basic { + +enum InputTensor { + kInputData = 0, + kInputPrevActivation = 1, + kInputWeights = 2, + kInputBiases = 3, + kInputPrevState = 4, + kInputNum = 5, +}; + +enum OutputTensor { + kOutputActivation = 0, + kOutputState = 1, + kOutputConcatTemp = 2, + kOutputActivationTemp = 3, + kOutputNum = 4, +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* op_data = new OpData; + op_data->kernel_type = kTfLiteLSTMBasicKernel; + // `scratch_tensor_index` is unused in this kernel. + op_data->scratch_tensor_index = -1; + return op_data; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE(context, node->inputs->size == kInputNum); + TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); + + // Only Float32 is supported currently. + // TODO(ycling): Implement quantize uint8 support. + for (int index = 0; index < node->inputs->size; ++index) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + TF_LITE_ENSURE_EQ(context, tensor->type, kTfLiteFloat32); + } + + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TF_LITE_ENSURE_EQ(context, input->dims->size, 2); + const int num_batches = input->dims->data[0]; + + TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); + + TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + + TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); + TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + TF_LITE_ENSURE_OK(context, context->ResizeTensor( + context, activation_out, + TfLiteIntArrayCopy(prev_activation->dims))); + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, state_out, + TfLiteIntArrayCopy(prev_state->dims))); + TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); + concat_temp_size->data[0] = num_batches; + concat_temp_size->data[1] = weights->dims->data[1]; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, concat_temp, concat_temp_size)); + TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); + activation_temp_size->data[0] = num_batches; + activation_temp_size->data[1] = weights->dims->data[0]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, + activation_temp_size)); + + // Set the state tensors as persistent. + for (auto index : {kInputPrevActivation, kInputPrevState}) { + TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; + tensor->allocation_type = kTfLiteArenaRwPersistent; + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputData); + const TfLiteTensor* prev_activation = + GetInput(context, node, kInputPrevActivation); + const TfLiteTensor* weights = GetInput(context, node, kInputWeights); + const TfLiteTensor* bias = GetInput(context, node, kInputBiases); + const TfLiteTensor* prev_state = GetInput(context, node, kInputPrevState); + + TfLiteTensor* activation_out = GetOutput(context, node, kOutputActivation); + TfLiteTensor* state_out = GetOutput(context, node, kOutputState); + TfLiteTensor* concat_temp = GetOutput(context, node, kOutputConcatTemp); + TfLiteTensor* activation_temp = + GetOutput(context, node, kOutputActivationTemp); + + optimized_ops::LstmCell( + // Inputs. + GetTensorData(input), GetTensorDims(input), + GetTensorData(prev_activation), GetTensorDims(prev_activation), + GetTensorData(weights), GetTensorDims(weights), + GetTensorData(bias), GetTensorDims(bias), + GetTensorData(prev_state), GetTensorDims(prev_state), + // Outputs. + GetTensorData(state_out), GetTensorDims(state_out), + GetTensorData(activation_out), GetTensorDims(activation_out), + GetTensorData(concat_temp), GetTensorDims(concat_temp), + GetTensorData(activation_temp), GetTensorDims(activation_temp)); + + // TODO(ycling): Investigate if this copy can be avoided with the 5-inputs + // LSTM kernel. + memcpy(prev_activation->data.raw, activation_out->data.raw, + activation_out->bytes); + memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes); + + return kTfLiteOk; +} + +} // namespace basic + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + const auto* params = reinterpret_cast(buffer); + switch (params->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Init(context, buffer, length); + case kTfLiteLSTMBasicKernel: + return basic::Init(context, buffer, length); + } +} +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Prepare(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Prepare(context, node); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const auto* op_data = reinterpret_cast(node->user_data); + switch (op_data->kernel_type) { + case kTfLiteLSTMFullKernel: + return full::Eval(context, node); + case kTfLiteLSTMBasicKernel: + return basic::Eval(context, node); + } +} + } // namespace lstm TfLiteRegistration* Register_LSTM() { diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index d81220d8d30793616444c03e8647b0877a39a4d9..6da29a4a923f16f7b5ad382f51cfd820783504cd 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite LSTM op. -#include #include #include @@ -35,7 +34,8 @@ class LSTMOpModel : public SingleOpModel { LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, bool use_peephole, bool use_projection_weights, bool use_projection_bias, float cell_clip, float proj_clip, - const std::vector>& input_shapes) + const std::vector>& input_shapes, + const TensorType& weight_type = TensorType_FLOAT32) : n_batch_(n_batch), n_input_(n_input), n_cell_(n_cell), @@ -45,31 +45,31 @@ class LSTMOpModel : public SingleOpModel { if (use_cifg) { input_to_input_weights_ = AddNullInput(); } else { - input_to_input_weights_ = AddInput(TensorType_FLOAT32); + input_to_input_weights_ = AddInput(weight_type); } - input_to_forget_weights_ = AddInput(TensorType_FLOAT32); - input_to_cell_weights_ = AddInput(TensorType_FLOAT32); - input_to_output_weights_ = AddInput(TensorType_FLOAT32); + input_to_forget_weights_ = AddInput(weight_type); + input_to_cell_weights_ = AddInput(weight_type); + input_to_output_weights_ = AddInput(weight_type); if (use_cifg) { recurrent_to_input_weights_ = AddNullInput(); } else { - recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_input_weights_ = AddInput(weight_type); } - recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); - recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_forget_weights_ = AddInput(weight_type); + recurrent_to_cell_weights_ = AddInput(weight_type); + recurrent_to_output_weights_ = AddInput(weight_type); if (use_peephole) { if (use_cifg) { cell_to_input_weights_ = AddNullInput(); } else { - cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + cell_to_input_weights_ = AddInput(weight_type); } - cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); - cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + cell_to_forget_weights_ = AddInput(weight_type); + cell_to_output_weights_ = AddInput(weight_type); } else { cell_to_input_weights_ = AddNullInput(); cell_to_forget_weights_ = AddNullInput(); @@ -86,7 +86,7 @@ class LSTMOpModel : public SingleOpModel { output_gate_bias_ = AddInput(TensorType_FLOAT32); if (use_projection_weights) { - projection_weights_ = AddInput(TensorType_FLOAT32); + projection_weights_ = AddInput(weight_type); if (use_projection_bias) { projection_bias_ = AddInput(TensorType_FLOAT32); } else { @@ -192,8 +192,9 @@ class LSTMOpModel : public SingleOpModel { zero_buffer.get() + zero_buffer_size); } - void SetInput(int offset, float* begin, float* end) { - PopulateTensor(input_, offset, begin, end); + void SetInput(int offset, const float* begin, const float* end) { + PopulateTensor(input_, offset, const_cast(begin), + const_cast(end)); } std::vector GetOutput() { return ExtractVector(output_); } @@ -203,7 +204,7 @@ class LSTMOpModel : public SingleOpModel { int num_cells() { return n_cell_; } int num_batches() { return n_batch_; } - private: + protected: int input_; int input_to_input_weights_; int input_to_forget_weights_; @@ -237,7 +238,182 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; -TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { +class HybridLSTMOpModel : public LSTMOpModel { + public: + HybridLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + bool use_cifg, bool use_peephole, + bool use_projection_weights, bool use_projection_bias, + float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : LSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg, use_peephole, + use_projection_weights, use_projection_bias, cell_clip, + proj_clip, input_shapes, TensorType_UINT8) {} + + void SetInputToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(cell_to_output_weights_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + SymmetricQuantizeAndPopulate(projection_weights_, f); + } +}; + +class BaseLstmTest : public ::testing::Test { + protected: + // Weights of the LSTM model. Some are optional. + std::initializer_list input_to_input_weights_; + std::initializer_list input_to_cell_weights_; + std::initializer_list input_to_forget_weights_; + std::initializer_list input_to_output_weights_; + std::initializer_list input_gate_bias_; + std::initializer_list cell_gate_bias_; + std::initializer_list forget_gate_bias_; + std::initializer_list output_gate_bias_; + std::initializer_list recurrent_to_input_weights_; + std::initializer_list recurrent_to_cell_weights_; + std::initializer_list recurrent_to_forget_weights_; + std::initializer_list recurrent_to_output_weights_; + std::initializer_list cell_to_input_weights_; + std::initializer_list cell_to_forget_weights_; + std::initializer_list cell_to_output_weights_; + std::initializer_list projection_weights_; + + // LSTM input is stored as num_batch x num_inputs vector. + std::vector> lstm_input_; + // LSTM output is stored as num_batch x num_outputs vector. + std::vector> lstm_golden_output_; + + // Compares output up to tolerance to the result of the lstm given the input. + void VerifyGoldens(const std::vector>& input, + const std::vector>& output, + LSTMOpModel* lstm, float tolerance = 1e-5) { + const int num_batches = input.size(); + EXPECT_GT(num_batches, 0); + const int num_inputs = lstm->num_inputs(); + EXPECT_GT(num_inputs, 0); + const int input_sequence_size = input[0].size() / num_inputs; + EXPECT_GT(input_sequence_size, 0); + for (int i = 0; i < input_sequence_size; ++i) { + for (int b = 0; b < num_batches; ++b) { + const float* batch_start = input[b].data() + i * num_inputs; + const float* batch_end = batch_start + num_inputs; + + lstm->SetInput(b * lstm->num_inputs(), batch_start, batch_end); + } + + lstm->Invoke(); + + const int num_outputs = lstm->num_outputs(); + std::vector expected; + for (int b = 0; b < num_batches; ++b) { + const float* golden_start_batch = output[b].data() + i * num_outputs; + const float* golden_end_batch = golden_start_batch + num_outputs; + expected.insert(expected.end(), golden_start_batch, golden_end_batch); + } + EXPECT_THAT(lstm->GetOutput(), + ElementsAreArray(ArrayFloatNear(expected, tolerance))); + for (int i = 0; i < num_outputs; ++i) { + std::cout << lstm->GetOutput()[i] << ", "; + } + std::cout << std::endl; + for (int i = 0; i < num_outputs; ++i) { + std::cout << expected[i] << ", "; + } + std::cout << std::endl; + } + } +}; + +class NoCifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = {-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}; + input_to_cell_weights_ = {-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, -0.29909778}; + input_to_forget_weights_ = {0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}; + input_to_output_weights_ = {-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, + -0.1556896, 0.19487578}; + input_gate_bias_ = {0., 0., 0., 0.}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_input_weights_ = { + -0.0063535, -0.2042388, 0.31454784, -0.35746509, + 0.28902304, 0.08183324, -0.16555229, 0.02286911, + -0.13566875, 0.03034258, 0.48091322, -0.12528998, + 0.24077177, -0.51332325, -0.33502164, 0.10629296}; + + recurrent_to_cell_weights_ = { + -0.3407414, 0.24443203, -0.2078532, 0.26320225, + 0.05695659, -0.00123841, -0.4744786, -0.35869038, + -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}; + + recurrent_to_forget_weights_ = { + -0.48684245, -0.06655136, 0.42224967, 0.2112639, + 0.27654213, 0.20864892, -0.07646349, 0.45877004, + 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}; + + recurrent_to_output_weights_ = { + 0.43385774, -0.17194885, 0.2718237, 0.09215671, + 0.24107647, -0.39835793, 0.18212086, 0.01301402, + 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}}; + } +}; + +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -257,10 +433,10 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {n_cell, n_input}, // input_to_cell_weight tensor {n_cell, n_input}, // input_to_output_weight tensor - {n_cell, n_output}, // recurrent_to_input_weight tensor - {n_cell, n_output}, // recurrent_to_forget_weight tensor - {n_cell, n_output}, // recurrent_to_cell_weight tensor - {n_cell, n_output}, // recurrent_to_output_weight tensor + {n_cell, n_output}, // recurrent_to_input_weight_tensor + {n_cell, n_output}, // recurrent_to_forget_weight_tensor + {n_cell, n_output}, // recurrent_to_cell_weight_tensor + {n_cell, n_output}, // recurrent_to_output_weight_tensor {0}, // cell_to_input_weight tensor {0}, // cell_to_forget_weight tensor @@ -275,79 +451,137 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, - -0.34550029, 0.04266912, -0.15680569, - -0.34856534, 0.43890524}); - - lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, - -0.20583314, 0.44344562, 0.22077113, - -0.29909778}); - - lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, - -0.31343272, -0.40032279, 0.44781327, - 0.01387155, -0.35593212}); - - lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, - 0.40525138, 0.44272184, 0.03897077, -0.1556896, - 0.19487578}); + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetInputGateBias({0., 0., 0., 0.}); + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); - - lstm.SetOutputGateBias({0., 0., 0., 0.}); - - lstm.SetRecurrentToInputWeights( - {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, - -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, - -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); - - lstm.SetRecurrentToCellWeights( - {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, - -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, - -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToForgetWeights( - {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, - -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, - 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetRecurrentToOutputWeights( - {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, - 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, - -0.51818722, -0.15390486, 0.0468148, 0.39922136}); +TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, - -0.15358765, -0.03716109, 0.12507336, - 0.41193449, -0.20860538, -0.15053082, - 0.09120187, 0.24278517, -0.12222792}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, + /*tolerance=*/0.0157651); +} - lstm.SetInput(0, batch0_start, batch0_end); +class CifgNoPeepholeNoProjectionNoClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_cell_weights_ = {-0.49770179, -0.27711356, -0.09624726, + 0.05100781, 0.04717243, 0.48944736, + -0.38535351, -0.17212132}; - lstm.Invoke(); + input_to_forget_weights_ = {-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, + 0.24407166, 0.33826375}; - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + input_to_output_weights_ = {0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}; + cell_gate_bias_ = {0., 0., 0., 0.}; + forget_gate_bias_ = {1., 1., 1., 1.}; + output_gate_bias_ = {0., 0., 0., 0.}; + + recurrent_to_cell_weights_ = { + 0.54066205, -0.32668582, -0.43562764, -0.56094903, + 0.42957711, 0.01841056, -0.32764608, -0.33027974, + -0.10826075, 0.20675004, 0.19069612, -0.03026325, + -0.54532051, 0.33003211, 0.44901288, 0.21193194}; + + recurrent_to_forget_weights_ = { + -0.13832897, -0.0515101, -0.2359007, -0.16661474, + -0.14340827, 0.36986142, 0.23414481, 0.55899, + 0.10798943, -0.41174671, 0.17751795, -0.34484994, + -0.35874045, -0.11352962, 0.27268326, 0.54058349}; + + recurrent_to_output_weights_ = { + 0.41613156, 0.42610586, -0.16495961, -0.5663873, + 0.30579174, -0.05115908, -0.33941799, 0.23364776, + 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}; + + cell_to_forget_weights_ = {0.47485286, -0.51955009, -0.24458408, + 0.31544167}; + cell_to_output_weights_ = {-0.17135078, 0.82760304, 0.85573703, + -0.77109635}; + + lstm_input_ = {{2., 3., 3., 4., 1., 1.}}; + lstm_golden_output_ = {{-0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 1; const int n_input = 2; // n_cell and n_output have the same size when there is no projection. @@ -385,74 +619,689 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, - 0.04717243, 0.48944736, -0.38535351, - -0.17212132}); - - lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, - -0.3633365, -0.22755712, 0.28253698, 0.24407166, - 0.33826375}); - - lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, - -0.09426838, -0.44257352, 0.54939759, - 0.01533556, 0.42751634}); - - lstm.SetCellBias({0., 0., 0., 0.}); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); - lstm.SetForgetGateBias({1., 1., 1., 1.}); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); - lstm.SetOutputGateBias({0., 0., 0., 0.}); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); - lstm.SetRecurrentToCellWeights( - {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, - 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, - 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, - 0.21193194}); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); - lstm.SetRecurrentToForgetWeights( - {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, - 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, - -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - lstm.SetRecurrentToOutputWeights( - {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, - -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, - 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetCellToForgetWeights( - {0.47485286, -0.51955009, -0.24458408, 0.31544167}); - lstm.SetCellToOutputWeights( - {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); +TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; - static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; - static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, - -0.05163646, -0.42312205, -0.01218222, - 0.24201041, -0.08124574, -0.358325, - -0.04621704, 0.21641694, -0.06471302}; + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); - - lstm.SetInput(0, batch0_start, batch0_end); - - lstm.Invoke(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573); +} - float* golden_start = lstm_golden_output + i * lstm.num_outputs(); - float* golden_end = golden_start + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start, golden_end); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest { + void SetUp() override { + input_to_input_weights_ = { + 0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}; + + input_to_forget_weights_ = { + -0.0018401089, -0.004852237, 0.03698424, 0.014181704, + 0.028273236, -0.016726194, -0.05249759, -0.10204261, + 0.00861066, -0.040979505, -0.009899187, 0.01923892, + -0.028177269, -0.08535103, -0.14585495, 0.10662567, + -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, + 0.0814421, -0.12257899, -0.033945758, -0.031303465, + 0.045630626, 0.06843887, -0.13492945, -0.012480007, + -0.0811829, -0.07224499, -0.09628791, 0.045100946, + 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, + 0.052625068, 0.12784666, 0.07077897, 0.025725935, + 0.04165009, 0.07241905, 0.018668644, -0.037377294, + -0.06277783, -0.08833636, -0.040120605, -0.011405586, + -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, + 0.13396506, -0.08402166, -0.01901462, -0.044678304, + -0.07720565, 0.014350063, -0.11757958, -0.0652038, + -0.08185733, -0.076754324, -0.092614375, 0.10405491, + 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, + -0.054523353, 0.02582715, 0.02327355, -0.011857179, + -0.0011980024, -0.034641717, -0.026125094, -0.17582615, + -0.15923657, -0.27486774, -0.0006143371, 0.0001771948, + -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}; + + input_to_cell_weights_ = { + -0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}; + + input_to_output_weights_ = { + -0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}; + + input_gate_bias_ = {0.02234832, 0.14757581, 0.18176508, 0.10380666, + 0.053110216, -0.06928846, -0.13942584, -0.11816189, + 0.19483899, 0.03652339, -0.10250295, 0.036714908, + -0.18426876, 0.036065217, 0.21810818, 0.02383196, + -0.043370757, 0.08690144, -0.04444982, 0.00030581196}; + + forget_gate_bias_ = {0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}; + + cell_gate_bias_ = {-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}; + + output_gate_bias_ = {0.046159424, -0.0012809046, 0.03563469, 0.12648113, + 0.027195795, 0.35373217, -0.018957434, 0.008907322, + -0.0762701, 0.12018895, 0.04216877, 0.0022856654, + 0.040952638, 0.3147856, 0.08225149, -0.057416286, + -0.14995944, -0.008040261, 0.13208859, 0.029760877}; + + recurrent_to_input_weights_ = { + -0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}; + + recurrent_to_cell_weights_ = { + -0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}; + + recurrent_to_forget_weights_ = { + -0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}; + + recurrent_to_output_weights_ = { + 0.025825322, -0.05813119, 0.09495884, -0.045984812, + -0.01255415, -0.0026479573, -0.08196161, -0.054914974, + -0.0046604523, -0.029587349, -0.044576716, -0.07480124, + -0.082868785, 0.023254942, 0.027502948, -0.0039728214, + -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, + -0.08829125, -0.005139627, -0.08989442, -0.0555066, + 0.13596267, -0.025062224, -0.048351806, -0.03850004, + 0.07266485, -0.022414139, 0.05940088, 0.075114764, + 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, + 0.014416728, 0.043229222, 0.034178585, -0.07530371, + 0.035837382, -0.085607, -0.007721233, -0.03287832, + -0.043848954, -0.06404588, -0.06632928, -0.073643476, + 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, + -0.030063879, 0.008801774, -0.023021035, -0.019558564, + 0.05158114, -0.010947698, -0.011825728, 0.0075720972, + 0.0699727, -0.0039981045, 0.069350146, 0.08799282, + 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, + -0.00924166, 0.0046702605, -0.036598757, -0.08811812, + 0.10522024, -0.032441203, 0.008176899, -0.04454919, + 0.07058152, 0.0067963637, 0.039206743, 0.03259838, + 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, + 0.036879618, 0.043357447, 0.028362012, -0.05908629, + 0.0059240665, -0.04995891, -0.019187413, 0.0276265, + -0.01628143, 0.0025863599, 0.08800015, 0.035250366, + -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, + -0.009660886, 0.019076364, 0.018299393, -0.046004917, + 0.08891175, 0.0431396, -0.026327137, -0.051502608, + 0.08979574, -0.051670972, 0.04940282, -0.07491107, + -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, + -0.035936575, -0.011681591, 0.064818054, 0.0073146066, + -0.021745546, -0.043124277, -0.06471268, -0.07053354, + -0.029321948, -0.05330136, 0.016933719, -0.053782392, + 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, + -0.07924483, 0.06936997, 0.0034815092, -0.007305279, + -0.037325785, -0.07251102, -0.033633437, -0.08677009, + 0.091591336, -0.14165086, 0.021752775, 0.019683983, + 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, + 0.1183656, -0.0010731248, -0.023590032, -0.072285876, + -0.0724771, -0.026382286, -0.0014920527, 0.042667855, + 0.0018776858, 0.02986552, 0.009814309, 0.0733756, + 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, + -0.010036754, 0.02576849, -0.08307328, 0.010112348, + 0.042521734, -0.05869831, -0.071689695, 0.03876447, + -0.13275425, -0.0352966, -0.023077697, 0.10285965, + 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, + -0.08271222, -0.0030240538, -0.016368777, 0.1070414, + 0.042672627, 0.013456989, -0.0437609, -0.022309763, + 0.11576483, 0.04108048, 0.061026827, -0.0190714, + -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, + -0.023771819, -0.01965048, 0.007955471, -0.043740474, + 0.03346837, -0.10549954, 0.090567775, 0.042013682, + -0.03176985, 0.12569028, -0.02421228, -0.029526481, + 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, + -0.06861939, -0.021256343, -0.041093912, -0.06669611, + 0.035498552, 0.021757556, -0.09302526, -0.015403468, + -0.06614931, -0.051798206, -0.013874718, 0.03630673, + 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, + -0.020674974, -0.03944324, -0.008110165, -0.11113267, + 0.08484226, 0.043586485, 0.040582247, 0.0968012, + -0.065249965, -0.028036479, 0.0050708856, 0.0017462453, + 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, + -0.11768019, 0.085926116, -0.08251791, -0.045081906, + 0.0948852, 0.068401024, 0.024856757, 0.06978981, + -0.057309967, -0.012775832, -0.0032452994, 0.01977615, + -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }; + + cell_to_input_weights_ = { + 0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}; + + cell_to_forget_weights_ = { + -0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}; + + cell_to_output_weights_ = { + 0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}; + + projection_weights_ = { + -0.009802181, 0.09401916, 0.0717386, -0.13895074, + 0.09641832, 0.060420845, 0.08539281, 0.054285463, + 0.061395317, 0.034448683, -0.042991187, 0.019801661, + -0.16840284, -0.015726732, -0.23041931, -0.024478018, + -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, + 0.16169067, 0.22465782, -0.03993472, -0.004017731, + 0.08633481, -0.28869787, 0.08682067, 0.17240396, + 0.014975425, 0.056431185, 0.031037588, 0.16702051, + 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, + 0.014777949, -0.20203483, 0.094781205, 0.19100232, + 0.13987629, -0.036132768, -0.06426278, -0.05108664, + 0.13221376, 0.009441198, -0.16715929, 0.15859416, + -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, + 0.0046453946, 0.050794356, 0.10770313, -0.20790008, + -0.07149004, -0.11425117, 0.008225835, -0.035802525, + 0.14374903, 0.15262283, 0.048710253, 0.1847461, + -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, + 0.016261552, 0.022461696, 0.12689082, -0.043589946, + -0.12035478, -0.08361797, -0.050666027, -0.1248618, + -0.1275799, -0.071875185, 0.07377272, 0.09944291, + -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, + 0.041546922, -0.20424393, 0.06907816, 0.050412357, + 0.00724631, 0.039827548, 0.12449835, 0.10747581, + 0.13708383, 0.09134148, -0.12617786, -0.06428341, + 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, + 0.022913318, -0.042050496, 0.16842307, -0.060597885, + 0.10531834, -0.06411776, -0.07451711, -0.03410368, + -0.13393489, 0.06534304, 0.003620307, 0.04490757, + 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, + -0.024575593, -0.036445823, 0.07155557, 0.009672501, + -0.02328883, 0.009533515, -0.03606021, -0.07421458, + -0.028082801, -0.2678904, -0.13221288, 0.18419984, + -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, + 0.14494097, -0.12522776, -0.098633975, -0.10766018, + -0.08317623, 0.08594209, 0.07749552, 0.039474737, + 0.1776665, -0.07409566, -0.0477268, 0.29323658, + 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, + 0.10034707, 0.045594677, 0.0635285, -0.0715442, + -0.089667566, -0.10811871, 0.00026344223, 0.08298446, + -0.009525053, 0.006585689, -0.24567553, -0.09450807, + 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, + 0.067035615, 0.19271925, -0.0032889997, -0.043264326, + 0.09663576, -0.057112187, -0.10100678, 0.0628376, + 0.04447668, 0.017961001, -0.10094388, -0.10190601, + 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, + -0.1947263, 0.02251204, 0.11216432, -0.10307853, + 0.17351969, -0.039091777, 0.08066188, -0.00561982, + 0.12633002, 0.11335965, -0.0088127935, -0.019777594, + 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, + -0.07468996, -0.0855457, 0.099339016, -0.07580735, + -0.13775392, 0.08434318, 0.08330512, -0.12131499, + 0.031935584, 0.09180414, -0.08876437, -0.08049874, + 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, + 0.04331237, 0.04299654, -0.036394123, -0.12915532, + 0.09793732, 0.07512415, -0.11319543, -0.032502122, + 0.15661901, 0.07671967, -0.005491124, -0.19379048, + -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, + -0.09334311, 0.15026465, -0.15493552, -0.057762887, + -0.11604192, -0.262013, -0.01391798, 0.012185008, + 0.11156489, -0.07483202, 0.06693364, -0.26151478, + 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, + 0.030635102, 0.010969227, 0.11109743, 0.010919218, + 0.027526086, 0.13519906, 0.01891392, -0.046839405, + -0.040167913, 0.017953383, -0.09700955, 0.0061885654, + -0.07000971, 0.026893595, -0.038844477, 0.14543656}; + + lstm_input_ = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, // step 0 + 0.596268, 0.998386, 0.568695, 0.864524, 0.571277, // step 1 + 0.073204, 0.296072, 0.743333, 0.069199, 0.045348, // step 2 + 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, // step 3 + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, // step 0 + 0.642421, 0.524260, 0.134799, 0.003639, 0.162482, // step 1 + 0.640394, 0.930399, 0.050782, 0.432485, 0.988078, // step 2 + 0.082922, 0.563329, 0.865614, 0.333232, 0.259916} // step 3 + }; + + lstm_golden_output_ = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; } -} +}; -TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) { const int n_batch = 2; const int n_input = 5; const int n_cell = 20; @@ -489,588 +1338,98 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToInputWeights( - {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, - 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, - -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, - -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, - -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, - -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, - -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, - 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, - 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, - 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, - -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, - 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, - -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, - -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, - -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, - 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, - -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, - -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, - -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, - -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); - - lstm.SetInputToForgetWeights( - {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, - -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, - -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, - 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, - 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, - -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, - -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, - 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, - 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, - 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, - 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, - -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, - 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, - -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, - -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, - 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, - 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, - 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, - -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, - 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); - - lstm.SetInputToCellWeights( - {-0.04580283, -0.09549462, -0.032418985, -0.06454633, - -0.043528453, 0.043018587, -0.049152344, -0.12418144, - -0.078985475, -0.07596889, 0.019484362, -0.11434962, - -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, - -0.025034338, -0.0028890965, 0.048929527, 0.06235075, - 0.10665918, -0.032036792, -0.08505916, -0.10843358, - -0.13002433, -0.036816437, -0.02130134, -0.016518239, - 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, - -0.10652836, -0.1037554, -0.13056071, -0.03266643, - -0.033702414, -0.006473424, -0.04611692, 0.014419339, - -0.025174323, 0.0396852, 0.081777506, 0.06157468, - 0.10210095, -0.009658194, 0.046511717, 0.03603906, - 0.0069369148, 0.015960095, -0.06507666, 0.09551598, - 0.053568836, 0.06408714, 0.12835667, -0.008714329, - -0.20211966, -0.12093674, 0.029450472, 0.2849013, - -0.029227901, 0.1164364, -0.08560263, 0.09941786, - -0.036999565, -0.028842626, -0.0033637602, -0.017012902, - -0.09720865, -0.11193351, -0.029155117, -0.017936034, - -0.009768936, -0.04223324, -0.036159635, 0.06505112, - -0.021742892, -0.023377212, -0.07221364, -0.06430552, - 0.05453865, 0.091149814, 0.06387331, 0.007518393, - 0.055960953, 0.069779344, 0.046411168, 0.10509911, - 0.07463894, 0.0075130584, 0.012850982, 0.04555431, - 0.056955688, 0.06555285, 0.050801456, -0.009862683, - 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); - - lstm.SetInputToOutputWeights( - {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, - -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, - 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, - -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, - -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, - 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, - -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, - -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, - -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, - -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, - 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, - 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, - 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, - -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, - 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, - 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, - -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, - 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, - -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, - -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); - - lstm.SetInputGateBias( - {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, - -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, - -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, - 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); - - lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, - 0.11098921, 0.15378423, 0.09263801, 0.09790885, - 0.09508917, 0.061199076, 0.07665568, -0.015443159, - -0.03499149, 0.046190713, 0.08895977, 0.10899629, - 0.40694186, 0.06030037, 0.012413437, -0.06108739}); - - lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, - -0.1483596, -0.10639995, -0.091433935, 0.058573797, - -0.06809782, -0.07889636, -0.043246906, -0.09829136, - -0.4279842, 0.034901652, 0.18797937, 0.0075234566, - 0.016178843, 0.1749513, 0.13975595, 0.92058027}); - - lstm.SetOutputGateBias( - {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, - 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, - 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, - -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); - - lstm.SetRecurrentToInputWeights( - {-0.001374326, -0.078856036, 0.10672688, 0.029162422, - -0.11585556, 0.02557986, -0.13446963, -0.035785314, - -0.01244275, 0.025961924, -0.02337298, -0.044228926, - -0.055839065, -0.046598054, -0.010546039, -0.06900766, - 0.027239809, 0.022582639, -0.013296484, -0.05459212, - 0.08981, -0.045407712, 0.08682226, -0.06867011, - -0.14390695, -0.02916037, 0.000996957, 0.091420636, - 0.14283475, -0.07390571, -0.06402044, 0.062524505, - -0.093129106, 0.04860203, -0.08364217, -0.08119002, - 0.009352075, 0.22920375, 0.0016303885, 0.11583097, - -0.13732095, 0.012405723, -0.07551853, 0.06343048, - 0.12162708, -0.031923793, -0.014335606, 0.01790974, - -0.10650317, -0.0724401, 0.08554849, -0.05727212, - 0.06556731, -0.042729504, -0.043227166, 0.011683251, - -0.013082158, -0.029302018, -0.010899579, -0.062036745, - -0.022509435, -0.00964907, -0.01567329, 0.04260106, - -0.07787477, -0.11576462, 0.017356863, 0.048673786, - -0.017577527, -0.05527947, -0.082487635, -0.040137455, - -0.10820036, -0.04666372, 0.022746278, -0.07851417, - 0.01068115, 0.032956902, 0.022433773, 0.0026891115, - 0.08944216, -0.0685835, 0.010513544, 0.07228705, - 0.02032331, -0.059686817, -0.0005566496, -0.086984694, - 0.040414046, -0.1380399, 0.094208956, -0.05722982, - 0.012092817, -0.04989123, -0.086576, -0.003399834, - -0.04696032, -0.045747425, 0.10091314, 0.048676282, - -0.029037097, 0.031399418, -0.0040285117, 0.047237843, - 0.09504992, 0.041799378, -0.049185462, -0.031518843, - -0.10516937, 0.026374253, 0.10058866, -0.0033195973, - -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, - -0.10167381, 0.042500053, -0.01447153, 0.06464186, - -0.017142897, 0.03312627, 0.009205989, 0.024138335, - -0.011337001, 0.035530265, -0.010912711, 0.0706555, - -0.005894094, 0.051841937, -0.1401738, -0.02351249, - 0.0365468, 0.07590991, 0.08838724, 0.021681072, - -0.10086113, 0.019608743, -0.06195883, 0.077335775, - 0.023646897, -0.095322326, 0.02233014, 0.09756986, - -0.048691444, -0.009579111, 0.07595467, 0.11480546, - -0.09801813, 0.019894179, 0.08502348, 0.004032281, - 0.037211012, 0.068537936, -0.048005626, -0.091520436, - -0.028379958, -0.01556313, 0.06554592, -0.045599163, - -0.01672207, -0.020169014, -0.011877351, -0.20212261, - 0.010889619, 0.0047078193, 0.038385306, 0.08540671, - -0.017140968, -0.0035865551, 0.016678626, 0.005633034, - 0.015963363, 0.00871737, 0.060130805, 0.028611384, - 0.10109069, -0.015060172, -0.07894427, 0.06401885, - 0.011584063, -0.024466386, 0.0047652307, -0.09041358, - 0.030737216, -0.0046374933, 0.14215417, -0.11823516, - 0.019899689, 0.006106124, -0.027092824, 0.0786356, - 0.05052217, -0.058925, -0.011402121, -0.024987547, - -0.0013661642, -0.06832946, -0.015667673, -0.1083353, - -0.00096863037, -0.06988685, -0.053350925, -0.027275559, - -0.033664223, -0.07978348, -0.025200296, -0.017207067, - -0.058403496, -0.055697463, 0.005798788, 0.12965427, - -0.062582195, 0.0013350133, -0.10482091, 0.0379771, - 0.072521195, -0.0029455067, -0.13797039, -0.03628521, - 0.013806405, -0.017858358, -0.01008298, -0.07700066, - -0.017081132, 0.019358726, 0.0027079724, 0.004635139, - 0.062634714, -0.02338735, -0.039547626, -0.02050681, - 0.03385117, -0.083611414, 0.002862572, -0.09421313, - 0.058618143, -0.08598433, 0.00972939, 0.023867095, - -0.053934585, -0.023203006, 0.07452513, -0.048767887, - -0.07314807, -0.056307215, -0.10433547, -0.06440842, - 0.04328182, 0.04389765, -0.020006588, -0.09076438, - -0.11652589, -0.021705797, 0.03345259, -0.010329105, - -0.025767034, 0.013057034, -0.07316461, -0.10145612, - 0.06358255, 0.18531723, 0.07759293, 0.12006465, - 0.1305557, 0.058638252, -0.03393652, 0.09622831, - -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, - -0.005644518, 0.06857898, -0.12598175, -0.035084512, - 0.03156317, -0.12794146, -0.031963028, 0.04692781, - 0.030070418, 0.0071660685, -0.095516115, -0.004643372, - 0.040170413, -0.062104587, -0.0037324072, 0.0554317, - 0.08184801, -0.019164372, 0.06791302, 0.034257166, - -0.10307039, 0.021943003, 0.046745934, 0.0790918, - -0.0265588, -0.007824208, 0.042546265, -0.00977924, - -0.0002440307, -0.017384544, -0.017990116, 0.12252321, - -0.014512694, -0.08251313, 0.08861942, 0.13589665, - 0.026351685, 0.012641483, 0.07466548, 0.044301085, - -0.045414884, -0.051112458, 0.03444247, -0.08502782, - -0.04106223, -0.028126027, 0.028473156, 0.10467447}); - - lstm.SetRecurrentToForgetWeights( - {-0.057784554, -0.026057621, -0.068447545, -0.022581743, - 0.14811787, 0.10826372, 0.09471067, 0.03987225, - -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, - 0.08414449, -0.022036452, -0.00066928595, -0.09203576, - 0.032950465, -0.10985798, -0.023809856, 0.0021431844, - -0.02196096, -0.00326074, 0.00058621005, -0.074678116, - -0.06193199, 0.055729095, 0.03736828, 0.020123724, - 0.061878487, -0.04729229, 0.034919553, -0.07585433, - -0.04421272, -0.044019096, 0.085488975, 0.04058006, - -0.06890133, -0.030951202, -0.024628663, -0.07672815, - 0.034293607, 0.08556707, -0.05293577, -0.033561368, - -0.04899627, 0.0241671, 0.015736353, -0.095442444, - -0.029564252, 0.016493602, -0.035026584, 0.022337519, - -0.026871363, 0.004780428, 0.0077918363, -0.03601621, - 0.016435321, -0.03263031, -0.09543275, -0.047392778, - 0.013454138, 0.028934088, 0.01685226, -0.086110644, - -0.046250615, -0.01847454, 0.047608484, 0.07339695, - 0.034546845, -0.04881143, 0.009128804, -0.08802852, - 0.03761666, 0.008096139, -0.014454086, 0.014361001, - -0.023502491, -0.0011840804, -0.07607001, 0.001856849, - -0.06509276, -0.006021153, -0.08570962, -0.1451793, - 0.060212336, 0.055259194, 0.06974018, 0.049454916, - -0.027794661, -0.08077226, -0.016179763, 0.1169753, - 0.17213494, -0.0056326236, -0.053934924, -0.0124349, - -0.11520337, 0.05409887, 0.088759385, 0.0019655675, - 0.0042065294, 0.03881498, 0.019844765, 0.041858196, - -0.05695512, 0.047233116, 0.038937137, -0.06542224, - 0.014429736, -0.09719407, 0.13908425, -0.05379757, - 0.012321099, 0.082840554, -0.029899208, 0.044217527, - 0.059855383, 0.07711018, -0.045319796, 0.0948846, - -0.011724666, -0.0033288454, -0.033542685, -0.04764985, - -0.13873616, 0.040668588, 0.034832682, -0.015319203, - -0.018715994, 0.046002675, 0.0599172, -0.043107376, - 0.0294216, -0.002314414, -0.022424703, 0.0030315618, - 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, - 0.12375372, -0.0006038222, 0.029104086, 0.087442465, - 0.052958444, 0.07558703, 0.04817258, 0.044462286, - -0.015213451, -0.08783778, -0.0561384, -0.003008196, - 0.047060397, -0.002058388, 0.03429439, -0.018839769, - 0.024734668, 0.024614193, -0.042046934, 0.09597743, - -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, - -0.02558259, -0.022822596, -0.023273505, -0.02464396, - -0.10991725, -0.006240552, 0.0074488563, 0.024044557, - 0.04383914, -0.046476185, 0.028658995, 0.060410924, - 0.050786525, 0.009452605, -0.0073054377, -0.024810238, - 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, - 0.015898481, 0.021362653, -0.030262267, 0.016587038, - -0.011442813, 0.041154444, -0.007631438, -0.03423484, - -0.010977775, 0.036152758, 0.0066366293, 0.11915515, - 0.02318443, -0.041350313, 0.021485701, -0.10906167, - -0.028218046, -0.00954771, 0.020531068, -0.11995105, - -0.03672871, 0.024019798, 0.014255957, -0.05221243, - -0.00661567, -0.04630967, 0.033188973, 0.10107534, - -0.014027541, 0.030796422, -0.10270911, -0.035999842, - 0.15443139, 0.07684145, 0.036571592, -0.035900835, - -0.0034699554, 0.06209149, 0.015920248, -0.031122351, - -0.03858649, 0.01849943, 0.13872518, 0.01503974, - 0.069941424, -0.06948533, -0.0088794185, 0.061282158, - -0.047401894, 0.03100163, -0.041533746, -0.10430945, - 0.044574402, -0.01425562, -0.024290353, 0.034563623, - 0.05866852, 0.023947537, -0.09445152, 0.035450947, - 0.02247216, -0.0042998926, 0.061146557, -0.10250651, - 0.020881841, -0.06747029, 0.10062043, -0.0023941975, - 0.03532124, -0.016341697, 0.09685456, -0.016764693, - 0.051808182, 0.05875331, -0.04536488, 0.001626336, - -0.028892258, -0.01048663, -0.009793449, -0.017093895, - 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, - -0.001845119, -0.03551521, 0.0018358806, 0.05763657, - -0.01769146, 0.040995963, 0.02235177, -0.060430344, - 0.11475477, -0.023854522, 0.10071741, 0.0686208, - -0.014250481, 0.034261297, 0.047418304, 0.08562733, - -0.030519066, 0.0060542435, 0.014653856, -0.038836084, - 0.04096551, 0.032249358, -0.08355519, -0.026823482, - 0.056386515, -0.010401743, -0.028396193, 0.08507674, - 0.014410365, 0.020995233, 0.17040324, 0.11511526, - 0.02459721, 0.0066619175, 0.025853224, -0.023133837, - -0.081302024, 0.017264642, -0.009585969, 0.09491168, - -0.051313367, 0.054532815, -0.014298593, 0.10657464, - 0.007076659, 0.10964551, 0.0409152, 0.008275321, - -0.07283536, 0.07937492, 0.04192024, -0.1075027}); - - lstm.SetRecurrentToCellWeights( - {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, - 0.055647098, -0.05713207, -0.05626563, 0.005559383, - 0.03375411, -0.025757805, -0.088049285, 0.06017052, - -0.06570978, 0.007384076, 0.035123326, -0.07920549, - 0.053676967, 0.044480428, -0.07663568, 0.0071805613, - 0.08089997, 0.05143358, 0.038261272, 0.03339287, - -0.027673481, 0.044746667, 0.028349208, 0.020090483, - -0.019443132, -0.030755889, -0.0040000007, 0.04465846, - -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, - -0.10893326, 0.076739706, -0.08509834, -0.027997585, - 0.037871376, 0.01449768, -0.09002357, -0.06111149, - -0.046195522, 0.0422062, -0.005683705, -0.1253618, - -0.012925729, -0.04890792, 0.06985068, 0.037654128, - 0.03398274, -0.004781977, 0.007032333, -0.031787455, - 0.010868644, -0.031489216, 0.09525667, 0.013939797, - 0.0058680447, 0.0167067, 0.02668468, -0.04797466, - -0.048885044, -0.12722108, 0.035304096, 0.06554885, - 0.00972396, -0.039238118, -0.05159735, -0.11329045, - 0.1613692, -0.03750952, 0.06529313, -0.071974665, - -0.11769596, 0.015524369, -0.0013754242, -0.12446318, - 0.02786344, -0.014179351, 0.005264273, 0.14376344, - 0.015983658, 0.03406988, -0.06939408, 0.040699873, - 0.02111075, 0.09669095, 0.041345075, -0.08316494, - -0.07684199, -0.045768797, 0.032298047, -0.041805092, - 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, - -0.024950314, 0.11574242, 0.04508852, -0.04335324, - 0.06760663, -0.027437469, 0.07216407, 0.06977076, - -0.05438599, 0.034033038, -0.028602652, 0.05346137, - 0.043184172, -0.037189785, 0.10420091, 0.00882477, - -0.054019816, -0.074273005, -0.030617684, -0.0028467078, - 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, - 0.04361412, -0.007001822, 0.09631092, -0.06702025, - -0.042049985, -0.035070654, -0.04103342, -0.10273396, - 0.0544271, 0.037184782, -0.13150354, -0.0058036847, - -0.008264958, 0.042035464, 0.05891794, 0.029673764, - 0.0063542654, 0.044788733, 0.054816857, 0.062257513, - -0.00093483756, 0.048938446, -0.004952862, -0.007730018, - -0.04043371, -0.017094059, 0.07229206, -0.023670016, - -0.052195564, -0.025616996, -0.01520939, 0.045104615, - -0.007376126, 0.003533447, 0.006570588, 0.056037236, - 0.12436656, 0.051817212, 0.028532185, -0.08686856, - 0.11868599, 0.07663395, -0.07323171, 0.03463402, - -0.050708205, -0.04458982, -0.11590894, 0.021273347, - 0.1251325, -0.15313013, -0.12224372, 0.17228661, - 0.023029093, 0.086124025, 0.006445803, -0.03496501, - 0.028332196, 0.04449512, -0.042436164, -0.026587414, - -0.006041347, -0.09292539, -0.05678812, 0.03897832, - 0.09465633, 0.008115513, -0.02171956, 0.08304309, - 0.071401566, 0.019622514, 0.032163795, -0.004167056, - 0.02295182, 0.030739572, 0.056506045, 0.004612461, - 0.06524936, 0.059999723, 0.046395954, -0.0045512207, - -0.1335546, -0.030136576, 0.11584653, -0.014678886, - 0.0020118146, -0.09688814, -0.0790206, 0.039770417, - -0.0329582, 0.07922767, 0.029322514, 0.026405897, - 0.04207835, -0.07073373, 0.063781224, 0.0859677, - -0.10925287, -0.07011058, 0.048005477, 0.03438226, - -0.09606514, -0.006669445, -0.043381985, 0.04240257, - -0.06955775, -0.06769346, 0.043903265, -0.026784198, - -0.017840602, 0.024307009, -0.040079936, -0.019946516, - 0.045318738, -0.12233574, 0.026170589, 0.0074471775, - 0.15978073, 0.10185836, 0.10298046, -0.015476589, - -0.039390966, -0.072174534, 0.0739445, -0.1211869, - -0.0347889, -0.07943156, 0.014809798, -0.12412325, - -0.0030663363, 0.039695457, 0.0647603, -0.08291318, - -0.018529687, -0.004423833, 0.0037507233, 0.084633216, - -0.01514876, -0.056505352, -0.012800942, -0.06994386, - 0.012962922, -0.031234352, 0.07029052, 0.016418684, - 0.03618972, 0.055686004, -0.08663945, -0.017404709, - -0.054761406, 0.029065743, 0.052404847, 0.020238016, - 0.0048197987, -0.0214882, 0.07078733, 0.013016777, - 0.06262858, 0.009184685, 0.020785125, -0.043904778, - -0.0270329, -0.03299152, -0.060088247, -0.015162964, - -0.001828936, 0.12642565, -0.056757294, 0.013586685, - 0.09232601, -0.035886683, 0.06000002, 0.05229691, - -0.052580316, -0.082029596, -0.010794592, 0.012947712, - -0.036429964, -0.085508935, -0.13127148, -0.017744139, - 0.031502828, 0.036232427, -0.031581745, 0.023051167, - -0.05325106, -0.03421577, 0.028793324, -0.034633752, - -0.009881397, -0.043551125, -0.018609839, 0.0019097115, - -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); - - lstm.SetRecurrentToOutputWeights({ - 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, - -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, - -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, - -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, - -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, - -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, - -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, - 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, - -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, - 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, - -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, - -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, - 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, - 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, - -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, - 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, - 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, - 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, - 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, - 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, - -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, - 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, - -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, - 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, - 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, - 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, - -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, - -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, - -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, - -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, - -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, - -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, - 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, - 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, - -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, - 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, - -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, - -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, - -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, - 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, - 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, - 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, - -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, - 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, - -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, - -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, - -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, - -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, - 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, - -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, - 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, - -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, - -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, - -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, - -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, - 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, - 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, - -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, - 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, - 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, - -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, - 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, - 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, - 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, - }); - - lstm.SetCellToInputWeights( - {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, - -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, - -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, - 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); - - lstm.SetCellToForgetWeights( - {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, - -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, - -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, - 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); - - lstm.SetCellToOutputWeights( - {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, - -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, - -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, - 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); - - lstm.SetProjectionWeights( - {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, - 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, - -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, - -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, - 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, - 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, - 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, - 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, - -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, - -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, - -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, - 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, - 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, - 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, - 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, - 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, - -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, - 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, - -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, - 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, - -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, - -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, - 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, - -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, - 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, - -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, - -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, - 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, - -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, - -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, - -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, - 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, - 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, - -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, - 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, - 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, - 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, - 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, - 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, - -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, - -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, - 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, - -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, - -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, - 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, - 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, - 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, - -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, - -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, - -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, - 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, - -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, - 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, - 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, - -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, - -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, - -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, - 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, - -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, - -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, - -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, - 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, - 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, - 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); - - static float lstm_input[][20] = { - {// Batch0: 4 (input_sequence_size) * 5 (n_input) - 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, - 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, - 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, - - {// Batch1: 4 (input_sequence_size) * 5 (n_input) - 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, - 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, - 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; - - static float lstm_golden_output[][64] = { - {// Batch0: 4 (input_sequence_size) * 16 (n_output) - -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, - -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, - -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, - 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, - -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, - -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, - 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, - 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, - 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, - 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, - -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, - -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, - 0.0286833, 0.00824207, 0.0264887, 0.0305169}, - {// Batch1: 4 (input_sequence_size) * 16 (n_output) - -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, - -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, - 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, - 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, - -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, - -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, - 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, - 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, - 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, - 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, - -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, - -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, - 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); // Resetting cell_state and output_state lstm.ResetCellState(); lstm.ResetOutputState(); - const int input_sequence_size = - sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); - for (int i = 0; i < input_sequence_size; i++) { - float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); - float* batch0_end = batch0_start + lstm.num_inputs(); + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm); +} - lstm.SetInput(0, batch0_start, batch0_end); +TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; - float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); - float* batch1_end = batch1_start + lstm.num_inputs(); - lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + HybridLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights(input_to_input_weights_); + lstm.SetInputToCellWeights(input_to_cell_weights_); + lstm.SetInputToForgetWeights(input_to_forget_weights_); + lstm.SetInputToOutputWeights(input_to_output_weights_); + + lstm.SetInputGateBias(input_gate_bias_); + lstm.SetCellBias(cell_gate_bias_); + lstm.SetForgetGateBias(forget_gate_bias_); + lstm.SetOutputGateBias(output_gate_bias_); + + lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_); + lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_); + lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_); + lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_); + + lstm.SetCellToInputWeights(cell_to_input_weights_); + lstm.SetCellToForgetWeights(cell_to_forget_weights_); + lstm.SetCellToOutputWeights(cell_to_output_weights_); + + lstm.SetProjectionWeights(projection_weights_); - lstm.Invoke(); + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); - float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); - float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); - float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); - float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); - std::vector expected; - expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); - expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); - EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); - } + VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467); } } // namespace diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 62f4e94a386fbbc6987e8a6dc1a9a47ce3349cbb..b69a221447db963bcd3a7e6a69f132fe3767bfd1 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -120,8 +120,9 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, double real_multiplier = input1->params.scale * input2->params.scale / output->params.scale; - QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier, + &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 66069996049b1a415a6397f7321970eb90d0c050..fca040649d8ff76b42db1962135cd4315e961ec0 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -73,6 +73,7 @@ TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_STRIDED_SLICE(); TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_TOPK_V2(); +TfLiteRegistration* Register_LOG(); TfLiteRegistration* Register_LOG_SOFTMAX(); TfLiteRegistration* Register_CAST(); TfLiteRegistration* Register_DEQUANTIZE(); @@ -85,11 +86,16 @@ TfLiteRegistration* Register_GREATER_EQUAL(); TfLiteRegistration* Register_LESS(); TfLiteRegistration* Register_LESS_EQUAL(); TfLiteRegistration* Register_FLOOR(); +TfLiteRegistration* Register_TILE(); TfLiteRegistration* Register_NEG(); TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_TRANSPOSE_CONV(); +TfLiteRegistration* Register_EXPAND_DIMS(); +TfLiteRegistration* Register_SPARSE_TO_DENSE(); +TfLiteRegistration* Register_EQUAL(); +TfLiteRegistration* Register_NOT_EQUAL(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -123,7 +129,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); - AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM(), /* min_version */ 1, + /* max_version */ 2); AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, Register_BIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, @@ -144,6 +151,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); AddBuiltin(BuiltinOperator_EXP, Register_EXP()); AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_LOG, Register_LOG()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE()); @@ -161,6 +169,11 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); + AddBuiltin(BuiltinOperator_TILE, Register_TILE()); + AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS()); + AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); + AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL()); + AddBuiltin(BuiltinOperator_NOT_EQUAL, Register_NOT_EQUAL()); #if 0 // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index f2092eaa36db32ebbc959ac23365bb13dd034e68..86c4cd3ee88013ca4174f444d0388bc036d9cde6 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); - // TODO(ahentz): Our current implementations only support float32. - TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); // ResizeBilinear creates a float tensor even when the input is made of // integers. - output->type = kTfLiteFloat32; + output->type = input->type; if (!IsConstantTensor(size)) { SetTensorToDynamic(output); @@ -90,17 +88,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - GetTensorData(size), GetTensorDims(size), \ - GetTensorData(output), GetTensorDims(output), \ +#define TF_LITE_RESIZE_BILINEAR(type, datatype) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ params->align_corners) if (kernel_type == kReference) { - TF_LITE_RESIZE_BILINEAR(reference_ops); + TF_LITE_RESIZE_BILINEAR(reference_ops, float); } if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { - TF_LITE_RESIZE_BILINEAR(optimized_ops); + TF_LITE_RESIZE_BILINEAR(optimized_ops, float); + } + } else if (output->type == kTfLiteUInt8) { + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t); } #undef TF_LITE_RESIZE_BILINEAR } else { diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 4e03f3820a5c14ee1692c553db61e385716b1723..10caffea03ebcec7862df1627541ac3d076b04e4 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -22,6 +22,7 @@ namespace tflite { namespace { using ::testing::ElementsAreArray; +using uint8 = std::uint8_t; class ResizeBilinearOpModel : public SingleOpModel { public: @@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel { } else { size_ = AddInput({TensorType_INT32, {2}}); } - output_ = AddOutput(TensorType_FLOAT32); // Always float. + output_ = AddOutput(input.type); SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); @@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel { } } - void SetInput(std::initializer_list data) { + template + void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } - std::vector GetOutput() { return ExtractVector(output_); } + template + std::vector GetOutput() { + return ExtractVector(output_); + } private: int input_; @@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel { TEST(ResizeBilinearOpTest, HorizontalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); - m.SetInput({3, 6}); + m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); - const_m.SetInput({3, 6}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, HorizontalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}}); + m.SetInput({3, 6}); + m.SetSize({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); - m.SetInput({3, 9}); + m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); - const_m.SetInput({3, 9}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}}); + m.SetInput({3, 9}); + m.SetSize({3, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + EXPECT_THAT(const_m.GetOutput(), + ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12 // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); - m.SetInput({ + m.SetInput({ 3, 6, // 9, 12, // 4, 10, // @@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 6, // 9, 12, // 4, 10, // 10, 16 // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 5, 6, // - 7, 9, 10, // - 9, 11, 12, // - 4, 8, 10, // - 8, 12, 14, // - 10, 14, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); - m.SetInput({ + m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); m.SetSize({3, 3}); m.Invoke(); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); - const_m.SetInput({ + const_m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); const_m.Invoke(); - EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ - 3, 4, 5, 8, 6, 10, // - 7, 8, 9, 12, 10, 14, // - 9, 10, 11, 14, 12, 16, // - }))); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}}); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 13, 16, // + }))); } +TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) { + ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}}); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.SetSize({3, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); + + ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 13, 12, 16, // + }))); +} } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc new file mode 100644 index 0000000000000000000000000000000000000000..404c32ad9ca8b9f1e467b747708ccb451f2a5118 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -0,0 +1,275 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace sparse_to_dense { + +constexpr int kIndicesTensor = 0; +constexpr int kOutputShapeTensor = 1; +constexpr int kValueInputTensor = 2; +constexpr int kDefaultValueTensor = 3; +constexpr int kOutputTensor = 0; + +constexpr int kMaxDimensions = 4; + +template +TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape, + TfLiteTensor* output) { + const int output_dimensions = NumElements(output_shape); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions); + for (int i = 0; i < output_dimensions; ++i) { + output_shape_array->data[i] = GetTensorData(output_shape)[i]; + } + + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus CheckDimensionsMatch(TfLiteContext* context, + const TfLiteTensor* indices, + const TfLiteTensor* output_shape, + const TfLiteTensor* values) { + switch (NumDimensions(indices)) { + case 0: + case 1: { + if (NumDimensions(values) == 0) { + TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values)); + } + TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1); + break; + } + case 2: { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1), + NumElements(output_shape)); + if (NumDimensions(values) == 0) + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + NumElements(values)); + break; + } + default: + context->ReportError( + context, "Wrong indices dimensions %d, should be less than 3.", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Convert indices into a vector of 4-d vectors. +// TODO(renjieliu): Revisit here to improve the performance, since multiple +// allocations of std::vectors will be quite slow on phones. +template +TfLiteStatus GetIndicesVector(TfLiteContext* context, + const TfLiteTensor* indices, + const int num_indices, + std::vector>* indices_vector) { + // Note because TfLite will reverse the dimensions, so pad zeros upfront. + switch (NumDimensions(indices)) { + case 0: + case 1: { + const auto indices_data = GetTensorData(indices); + for (int i = 0; i < num_indices; ++i) { + std::vector index({0, 0, 0, indices_data[i]}); + indices_vector->push_back(index); + } + break; + } + case 2: { + const int true_dimensions = SizeOfDimension(indices, 1); + TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions); + for (int i = 0; i < num_indices; ++i) { + std::vector index; + index.reserve(kMaxDimensions); + // Fill the index with 1 up to kMaxDimensions - true_dimensions to + // satisfy the needs for 4-dimension index. + for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) { + index.push_back(0); + } + for (int j = 0; j < true_dimensions; ++j) { + index.push_back(GetTensorData(indices)[i * true_dimensions + j]); + } + + indices_vector->push_back(index); + } + break; + } + default: + context->ReportError(context, + "Indices dimensions problem, got %d dimensions", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { + if (output_shape->type == kTfLiteInt32) { + return Resize(context, output_shape, output); + } else if (output_shape->type == kTfLiteInt64) { + return Resize(context, output_shape, output); + } else { + context->ReportError(context, "Dense shape type %d not supported.", + output_shape->type); + return kTfLiteError; + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + + // TODO(renjieliu): Handle validate_indices. + + // Indices can be 0-D, 1-D or 2-D. + TF_LITE_ASSERT(NumDimensions(indices) >= 0); + TF_LITE_ENSURE(context, NumDimensions(indices) < 3); + TF_LITE_ASSERT(NumDimensions(output_shape) >= 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + // Values can be 0-D or 1-D. + TF_LITE_ASSERT(NumDimensions(values) >= 0); + TF_LITE_ENSURE(context, NumDimensions(values) < 2); + + TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1); + + TF_LITE_ENSURE( + context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 || + output_shape->type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, values->type, default_value->type); + + // Ensure dimensions match. + TF_LITE_ENSURE_OK( + context, CheckDimensionsMatch(context, indices, output_shape, values)); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + + if (!IsConstantTensor(output_shape)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputShape(context, output_shape, output); +} + +template +TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, output_shape, output)); + } + + const int num_indices = SizeOfDimension(indices, 0); + const bool value_is_scalar = NumDimensions(values) == 0; + std::vector> indices_vector; + indices_vector.reserve(num_indices); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); + reference_ops::SparseToDense(indices_vector, GetTensorData(values), + *GetTensorData(default_value), + GetTensorData(output), GetTensorDims(output), + value_is_scalar); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + + // Currently only supports float32 and int32. + switch (values->type) { + case kTfLiteFloat32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + case kTfLiteInt32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + values->type); + return kTfLiteError; + } +} + +} // namespace sparse_to_dense + +TfLiteRegistration* Register_SPARSE_TO_DENSE() { + static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare, + sparse_to_dense::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a51ec17afcefd791680d7aa42cef467f481f6dbc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc @@ -0,0 +1,155 @@ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class SparseToDenseOpModel : public SingleOpModel { + public: + SparseToDenseOpModel(std::initializer_list indices_shape, + std::initializer_list output_shape_shape, + std::initializer_list values_shape, T default_value, + TensorType tensor_index_type, + TensorType tensor_input_type) { + indices_ = AddInput(tensor_index_type); + output_shape_ = AddInput(TensorType_INT32); + values_ = AddInput(tensor_input_type); + default_value_ = AddInput(tensor_input_type); + output_ = AddOutput(tensor_input_type); + + SetBuiltinOp(BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOptions_SparseToDenseOptions, + CreateSparseToDenseOptions(builder_, false).Union()); + BuildInterpreter({indices_shape, output_shape_shape, values_shape, {1}}); + + PopulateTensor(default_value_, {default_value}); + } + + int indices() { return indices_; } + int output_shape() { return output_shape_; } + int values() { return values_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_shape_; + int values_; + int default_value_; + int output_; +}; + +TEST(SparseToDenseOpModelTest, ZeroDimensionTest) { + SparseToDenseOpModel m({1}, {1}, {1}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {3}); + m.PopulateTensor(m.output_shape(), {5}); + m.PopulateTensor(m.values(), {7}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 7, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(SparseToDenseOpModelTest, OneDimensionTest) { + SparseToDenseOpModel m({3}, {1}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {1, 3, 5}); + m.PopulateTensor(m.output_shape(), {7}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 0, 4, 0, 6, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({7})); +} + +TEST(SparseToDenseOpModelTest, TwoDimensionsTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 4, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, DefaultValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, IntegerValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_INT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, Int64IndexTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT64, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index d788159a8d80e6479024b7b75624839387a461c7..a8b803589962032db3ed579d31e8b736c3afada0 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -126,16 +126,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32 input1_multiplier; int input1_shift; - QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, - &input1_shift); + QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, + &input1_multiplier, &input1_shift); + input1_shift *= -1; int32 input2_multiplier; int input2_shift; - QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, - &input2_shift); + QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, + &input2_multiplier, &input2_shift); + input2_shift *= -1; int32 output_multiplier; int output_shift; - QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, - &output_shift); + QuantizeMultiplierSmallerThanOneExp(real_output_multiplier, + &output_multiplier, &output_shift); + output_shift *= -1; int32 output_activation_min, output_activation_max; CalculateActivationRangeUint8(params->activation, output, @@ -175,7 +178,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { output); } else { context->ReportError( - context, "output type %d is not support, requires float|uint8 types.", + context, "output type %d is not supported, requires float|uint8 types.", output->type); return kTfLiteError; } diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 1a01ee093626c08badd65858fc16ad44e69e4912..d23ec201b41887b0682242687fc938d76d058c44 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -112,6 +112,12 @@ void SingleOpModel::BuildInterpreter( if (shape.empty()) continue; CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); } + + // Modify delegate with function. + if (apply_delegate_fn_) { + apply_delegate_fn_(interpreter_.get()); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) << "Cannot allocate tensors"; } diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index 55edc97d19fa75bedb6c0928fcf9c7be5f434522..db80c0082c394a2cb2f9388d3db5bd1a7cbe6266 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -114,6 +114,13 @@ class SingleOpModel { SingleOpModel() {} ~SingleOpModel() {} + // Set a function callback that is run right after graph is prepared + // that allows applying external delegates. This is useful for testing + // other runtimes like NN API or GPU. + void SetApplyDelegate(std::function apply_delegate_fn) { + apply_delegate_fn_ = apply_delegate_fn; + } + // Copying or assignment is disallowed to simplify ownership semantics. SingleOpModel(const SingleOpModel&) = delete; SingleOpModel& operator=(const SingleOpModel&) = delete; @@ -317,6 +324,9 @@ class SingleOpModel { std::vector> operators_; std::vector> buffers_; std::map> custom_registrations_; + // A function pointer that gets called after the interpreter is created but + // before evaluation happens. This is useful for applying a delegate. + std::function apply_delegate_fn_; }; // Base class for single op unit tests. diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc new file mode 100644 index 0000000000000000000000000000000000000000..af77f074742eb3fef10a74616ff679255911fbb2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace tile { + +constexpr int kInputTensor = 0; +constexpr int kInputMultipliers = 1; +constexpr int kOutputTensor = 0; + +namespace { +template +TfLiteIntArray* MultiplyShapeDims(const TfLiteIntArray& shape, + const TfLiteTensor* multipliers, + int num_dimensions) { + const T* multipliers_v = GetTensorData(multipliers); + + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { + output_shape->data[i] = shape.data[i] * multipliers_v[i]; + } + return output_shape; +} + +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + const int num_dimensions = NumDimensions(input); + const int num_multipliers = NumElements(multipliers); + TF_LITE_ENSURE_EQ(context, num_dimensions, num_multipliers); + switch (multipliers->type) { + case kTfLiteInt32: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + case kTfLiteInt64: + return context->ResizeTensor( + context, output, + MultiplyShapeDims(*input->dims, multipliers, + num_dimensions)); + default: + context->ReportError(context, "Tile not supported multiply tensor type."); + return kTfLiteError; + } +} + +template +void CopyMultipleTimes(const T* in_data, int32_t in_size, int32_t multiplier, + T* out_data) { + for (int i = 0; i < multiplier; ++i) { + const T* in_end = in_data + in_size; + T* new_out_data = std::copy(in_data, in_end, out_data); + in_data = out_data; + out_data = new_out_data; + } +} + +template +std::pair TileOneDimension(const TfLiteIntArray& in_dimensions, + const T* in_data, const M* multipliers, + T* out_data, int dimension) { + const int dimension_size = in_dimensions.data[dimension]; + if (dimension == in_dimensions.size - 1) { + CopyMultipleTimes(in_data, dimension_size, multipliers[dimension], + out_data); + return std::make_pair(dimension_size, + dimension_size * multipliers[dimension]); + } + int total_stride_size = 0, total_tiled_stride_size = 0; + const T* copy_from_data = in_data; + T* copy_to_data = out_data; + for (int i = 0; i < dimension_size; ++i) { + int stride_size = 0, tiled_stride_size = 0; + std::tie(stride_size, tiled_stride_size) = + TileOneDimension(in_dimensions, copy_from_data, multipliers, + copy_to_data, dimension + 1); + copy_from_data += stride_size; + copy_to_data += tiled_stride_size; + total_stride_size += stride_size; + total_tiled_stride_size += tiled_stride_size; + } + CopyMultipleTimes(out_data, total_tiled_stride_size, + multipliers[dimension] - 1, + out_data + total_tiled_stride_size); + return std::make_pair(total_stride_size, + total_tiled_stride_size * multipliers[dimension]); +} + +template +void Tile(const TfLiteIntArray& in_dimensions, const TfLiteTensor* in_data, + const TfLiteTensor* multipliers, TfLiteTensor* out_data) { + // Doing recursively tiling from top to down dimension. + switch (multipliers->type) { + case kTfLiteInt32: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + case kTfLiteInt64: + TileOneDimension(in_dimensions, GetTensorData(in_data), + GetTensorData(multipliers), + GetTensorData(out_data), 0); + break; + default: + break; + } +} +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + // Only int32 and int64 multipliers type is supported. + TF_LITE_ENSURE_MSG(context, + (multipliers->type == kTfLiteInt32) || + (multipliers->type == kTfLiteInt64), + "Tile only supports int32 and int64 mutlipliers."); + + if (IsConstantTensor(multipliers)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + SetTensorToDynamic(output); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const TfLiteTensor* multipliers = GetInput(context, node, kInputMultipliers); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } + + switch (output->type) { + case kTfLiteFloat32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteUInt8: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt32: + Tile(*(input->dims), input, multipliers, output); + break; + case kTfLiteInt64: + Tile(*(input->dims), input, multipliers, output); + break; + default: + context->ReportError(context, "Type is currently not supported by Tile."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace tile +TfLiteRegistration* Register_TILE() { + static TfLiteRegistration r = {nullptr, nullptr, tile::Prepare, tile::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a134a75d56ae03a5d03a3cdf632146474b863971 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/tile_test.cc @@ -0,0 +1,256 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; +class TileOpModel : public SingleOpModel { + public: + TileOpModel(std::initializer_list input_shape, TensorType input_type, + TensorType multiply_type) { + input_ = AddInput(input_type); + multipliers_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_TILE, BuiltinOptions_TileOptions, 0); + BuildInterpreter({input_shape, {static_cast(input_shape.size())}}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt64(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetMultipliers(std::initializer_list data) { + PopulateTensor(multipliers_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + + std::vector GetOutputUInt8() { return ExtractVector(output_); } + + std::vector GetOutputInt32() { return ExtractVector(output_); } + + std::vector GetOutputInt64() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int multipliers_; + int output_; +}; + +TEST(TileTest, Float32Vector) { + TileOpModel m({3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({1.f, 2.f, 3.f}); + m.SetMultipliers({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray({1.f, 2.f, 3.f, 1.f, 2.f, 3.f})); +} + +TEST(TileTest, Float32Matrix) { + TileOpModel m({2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Float32HighDimension) { + TileOpModel m({1, 2, 3}, TensorType_FLOAT32, TensorType_INT32); + m.SetInputFloat({ + 11.f, + 12.f, + 13.f, + 21.f, + 22.f, + 23.f, + }); + m.SetMultipliers({2, 3, 1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray({11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, + 11.f, 12.f, 13.f, 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, + 21.f, 22.f, 23.f, 11.f, 12.f, 13.f, 21.f, 22.f, 23.f})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 6, 3})); +} + +TEST(TileTest, Uint8Matrix) { + TileOpModel m({2, 3}, TensorType_UINT8, TensorType_INT32); + m.SetInputUInt8({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputUInt8(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int32Matrix) { + TileOpModel m({2, 3}, TensorType_INT32, TensorType_INT32); + m.SetInputInt32({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt32(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT32); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} + +TEST(TileTest, Int64Matrix64Multipliers) { + TileOpModel m({2, 3}, TensorType_INT64, TensorType_INT64); + m.SetInputInt64({ + 11, + 12, + 13, + 21, + 22, + 23, + }); + m.SetMultipliers({2, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputInt64(), ElementsAreArray({ + 11, + 12, + 13, + 21, + 22, + 23, + 11, + 12, + 13, + 21, + 22, + 23, + })); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 3})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc index 3c99661029ed1ac881536f83519dcec355c60d50..e83b1ec9879d3c360203a52835d8486d0a9b81bb 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc @@ -79,7 +79,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Ensure that weights and inputs have the same channel dimension. // Note: TOCO will reorder weights in the following format: OHWI. TF_LITE_ENSURE_EQ(context, SizeOfDimension(input, 3), - SizeOfDimension(weights, 0)); + SizeOfDimension(weights, 3)); if (!IsConstantTensor(output_shape)) { SetTensorToDynamic(output); diff --git a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc index 52be08934997f484337e4a3592bc7af832601695..55df8971806ed0baae9f5bcaebd24fb8065ec300 100644 --- a/tensorflow/contrib/lite/kernels/transpose_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/transpose_conv_test.cc @@ -88,10 +88,10 @@ TEST(TransposeConvOpModelTest, SimpleTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[18, 1]) TEST(TransposeConvOpModelTest, TwoFiltersTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_SAME, 1, 1); + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_SAME, 1, 1); m.PopulateTensor(m.output_shape(), {1, 4, 4, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -117,10 +117,10 @@ TEST(TransposeConvOpModelTest, TwoFiltersTest) { // And filter value is derived by: // filter = tf.reshape(tf.transpose(filter, perm=[3, 0, 1, 2]), shape=[1, 18]) TEST(TransposeConvOpModelTest, PaddingValidTest) { - TransposeConvOpModel m({1, 4, 4, 2}, {2, 3, 3, 1}, Padding_VALID, 1, 1); + TransposeConvOpModel m({1, 4, 4, 2}, {1, 3, 3, 2}, Padding_VALID, 1, 1); m.PopulateTensor(m.output_shape(), {1, 6, 6, 1}); - m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, - 8, 10, 12, 14, 16, 18}); + m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18}); m.PopulateTensor( m.input(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, @@ -171,10 +171,10 @@ TEST(TransposeConvOpModelTest, StrideValidTest) { // [1, 2, 2, 1 ], // "VALID") TEST(TransposeConvOpModelTest, MultiChannelTest) { - TransposeConvOpModel m({1, 2, 2, 1}, {1, 3, 3, 2}, Padding_VALID, 2, 2); + TransposeConvOpModel m({1, 2, 2, 1}, {2, 3, 3, 1}, Padding_VALID, 2, 2); m.PopulateTensor(m.output_shape(), {1, 5, 5, 2}); - m.PopulateTensor(m.filter(), {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, - 13, 14, 15, 16, 17, 18}); + m.PopulateTensor(m.filter(), {1, 3, 5, 7, 9, 11, 13, 15, 17, 2, 4, 6, + 8, 10, 12, 14, 16, 18}); m.PopulateTensor(m.input(), {1, 2, 3, 4}); m.Invoke(); diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc index 8429dba54bd1806125aadc2119ca59c1bd42ce89..164a0cbd08d6ce82a413f12ba6b1703087a30aba 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -41,7 +41,7 @@ constexpr int kOutputTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* scratch_tensor_index = new int; - context->AddTensors(context, /*tensors_to_add=*/2, scratch_tensor_index); + context->AddTensors(context, /*tensors_to_add=*/3, scratch_tensor_index); return scratch_tensor_index; } @@ -102,7 +102,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (input->type == kTfLiteFloat32 && input_weights->type == kTfLiteUInt8) { int* scratch_tensor_index = reinterpret_cast(node->user_data); TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries = TfLiteIntArrayCreate(3); node->temporaries->data[0] = *scratch_tensor_index; TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/0); input_quantized->type = kTfLiteUInt8; @@ -125,6 +125,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { context->ResizeTensor(context, hidden_state_quantized, hidden_state_quantized_size)); } + node->temporaries->data[2] = *scratch_tensor_index + 2; + TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/2); + scaling_factors->type = kTfLiteFloat32; + scaling_factors->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); + scaling_factors_size->data[0] = batch_size; + if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) { + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, + scaling_factors_size)); + } } return kTfLiteOk; } @@ -187,14 +197,12 @@ TfLiteStatus EvalFloat(const TfLiteTensor* input, return kTfLiteOk; } -TfLiteStatus EvalQuantized(const TfLiteTensor* input, - const TfLiteTensor* input_weights, - const TfLiteTensor* recurrent_weights, - const TfLiteTensor* bias, - const TfLiteSequenceRNNParams* params, - TfLiteTensor* input_scratch, - TfLiteTensor* hidden_state_scratch, - TfLiteTensor* hidden_state, TfLiteTensor* output) { +TfLiteStatus EvalHybrid( + const TfLiteTensor* input, const TfLiteTensor* input_weights, + const TfLiteTensor* recurrent_weights, const TfLiteTensor* bias, + const TfLiteSequenceRNNParams* params, TfLiteTensor* input_scratch, + TfLiteTensor* hidden_state_scratch, TfLiteTensor* scaling_factors, + TfLiteTensor* hidden_state, TfLiteTensor* output) { const bool time_major = params->time_major; const int batch_size = (time_major) ? input->dims->data[1] : input->dims->data[0]; @@ -218,6 +226,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, reinterpret_cast(input_scratch->data.uint8); int8_t* quantized_hidden_state_ptr = reinterpret_cast(hidden_state_scratch->data.uint8); + float* scaling_factors_ptr = scaling_factors->data.f; if (time_major) { // Initialize the pointer to hidden state. @@ -233,7 +242,8 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, input_ptr_batch, input_weights_ptr, input_weights_scale, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, batch_size, params->activation, quantized_input_ptr, - quantized_hidden_state_ptr, hidden_state_ptr_batch, output_ptr_batch); + quantized_hidden_state_ptr, scaling_factors_ptr, + hidden_state_ptr_batch, output_ptr_batch); } } else { // For each batch @@ -252,7 +262,7 @@ TfLiteStatus EvalQuantized(const TfLiteTensor* input, recurrent_weights_ptr, recurrent_weights_scale, bias_ptr, input_size, num_units, /*batch_size=*/1, params->activation, quantized_input_ptr, quantized_hidden_state_ptr, - hidden_state_ptr_batch, output_ptr_batch); + scaling_factors_ptr, hidden_state_ptr_batch, output_ptr_batch); } } } @@ -278,9 +288,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(mirkov): implement eval with quantized inputs as well. TfLiteTensor* input_quantized = GetTemporary(context, node, 0); TfLiteTensor* hidden_state_quantized = GetTemporary(context, node, 1); - return EvalQuantized(input, input_weights, recurrent_weights, bias, - params, input_quantized, hidden_state_quantized, - hidden_state, output); + TfLiteTensor* scaling_factors = GetTemporary(context, node, 2); + return EvalHybrid(input, input_weights, recurrent_weights, bias, params, + input_quantized, hidden_state_quantized, + scaling_factors, hidden_state, output); } default: context->ReportError(context, "Type %d not currently supported.", diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 80fcb28bc7f6c09c7b979fcefcbc25deef583a00..039f32b38eb29068b223dd63355c66295301beba 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -322,12 +322,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = nullptr; switch (op_type) { - case BuiltinOperator_CALL: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - break; - case BuiltinOperator_CUSTOM: - break; case BuiltinOperator_CONV_2D: { TfLiteConvParams* params = MallocPOD(); if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { @@ -343,21 +337,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_TANH: - case BuiltinOperator_LOGISTIC: - case BuiltinOperator_RELU: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_RELU6: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_EXP: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_DEQUANTIZE: - case BuiltinOperator_PRELU: - case BuiltinOperator_FLOOR: - case BuiltinOperator_NEG: - case BuiltinOperator_SIN: - break; case BuiltinOperator_CAST: { TfLiteCastParams* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_CastOptions()) { @@ -445,9 +424,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_EMBEDDING_LOOKUP: - // no-op. - break; case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { TfLiteEmbeddingLookupSparseParams* params = MallocPOD(); @@ -558,6 +534,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, parse_activation(lstm_params->fused_activation_function()); params->cell_clip = lstm_params->cell_clip(); params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + } } *builtin_data = reinterpret_cast(params); break; @@ -571,12 +555,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_PAD: { - break; - } - case BuiltinOperator_PADV2: { - break; - } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { @@ -616,15 +594,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_SPACE_TO_BATCH_ND: { - break; - } - case BuiltinOperator_BATCH_TO_SPACE_ND: { - break; - } - case BuiltinOperator_TRANSPOSE: { - break; - } case BuiltinOperator_MEAN: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_MeanOptions()) { @@ -664,10 +633,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_MAXIMUM: - case BuiltinOperator_MINIMUM: { - break; - } case BuiltinOperator_ARG_MAX: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) { @@ -677,16 +642,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } - case BuiltinOperator_GREATER: - case BuiltinOperator_GREATER_EQUAL: - case BuiltinOperator_LESS: - case BuiltinOperator_LESS_EQUAL: - case BuiltinOperator_SELECT: { - break; - } - case BuiltinOperator_SLICE: { - break; - } case BuiltinOperator_TRANSPOSE_CONV: { TfLiteTransposeConvParams* params = MallocPOD(); @@ -699,11 +654,61 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SPARSE_TO_DENSE: { + TfLiteSparseToDenseParams* params = + MallocPOD(); + if (auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); return kTfLiteError; } + + // Below are the ops with no builtin_data strcture. + case BuiltinOperator_BATCH_TO_SPACE_ND: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + case BuiltinOperator_CALL: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_DEQUANTIZE: + case BuiltinOperator_EMBEDDING_LOOKUP: + case BuiltinOperator_EQUAL: + case BuiltinOperator_EXP: + case BuiltinOperator_EXPAND_DIMS: + case BuiltinOperator_FLOOR: + case BuiltinOperator_GREATER: + case BuiltinOperator_GREATER_EQUAL: + case BuiltinOperator_LESS: + case BuiltinOperator_LESS_EQUAL: + case BuiltinOperator_LOG: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MAXIMUM: + case BuiltinOperator_MINIMUM: + case BuiltinOperator_NEG: + case BuiltinOperator_NOT_EQUAL: + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: + case BuiltinOperator_PRELU: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU6: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_SELECT: + case BuiltinOperator_SIN: + case BuiltinOperator_SLICE: + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_TANH: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: + break; } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc index e6c8d966f1aff5a867f9469f8fcdec526df84763..c7e08814fdf502f1ecfea60af3385fc7aa6055fa 100644 --- a/tensorflow/contrib/lite/models/smartreply/predictor_test.cc +++ b/tensorflow/contrib/lite/models/smartreply/predictor_test.cc @@ -35,8 +35,8 @@ const char kModelName[] = "smartreply_ondevice_model.bin"; const char kSamples[] = "smartreply_samples.tsv"; string TestDataPath() { - return string(StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", - "contrib/lite/models/testdata/")); + return string(absl::StrCat(tensorflow::testing::TensorFlowSrcRoot(), "/", + "contrib/lite/models/testdata/")); } MATCHER_P(IncludeAnyResponesIn, expected_response, "contains the response") { @@ -55,7 +55,7 @@ class PredictorTest : public ::testing::Test { protected: PredictorTest() { model_ = tflite::FlatBufferModel::BuildFromFile( - StrCat(TestDataPath(), "/", kModelName).c_str()); + absl::StrCat(TestDataPath(), "/", kModelName).c_str()); CHECK(model_); } ~PredictorTest() override {} @@ -121,7 +121,7 @@ TEST_F(PredictorTest, BatchTest) { int total_triggers = 0; string line; - std::ifstream fin(StrCat(TestDataPath(), "/", kSamples)); + std::ifstream fin(absl::StrCat(TestDataPath(), "/", kSamples)); while (std::getline(fin, line)) { const std::vector fields = absl::StrSplit(line, '\t'); if (fields.empty()) { diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 21f1df0be320ba433d623641dc6abdf7d1f3ab69..c71ad1d37937228662798a7370d66147191032ec 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -159,7 +159,6 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, nn_type, static_cast(tensor->dims->size), reinterpret_cast(tensor->dims->data), scale, zeroPoint}; CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); - // TODO(aselle): Based on Michael's suggestion, limiting this to read // only memory if (tensor->allocation_type == kTfLiteMmapRo) { @@ -172,7 +171,12 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, CHECK_NN(ANeuralNetworksModel_setOperandValue( nn_model, next_id, tensor->data.raw, tensor->bytes)); } + } else if (tensor->bytes == 0) { + // These size 0 tensors are optional tensors reserved. + CHECK_NN( + ANeuralNetworksModel_setOperandValue(nn_model, next_id, nullptr, 0)); } + ++next_id; } return next_id; @@ -181,7 +185,9 @@ uint32_t addTensorOperands(tflite::Interpreter* interpreter, // Adds the operations and their parameters to the NN API model. // 'next-id' is the operand ID of the next operand of the model. void AddOpsAndParams(tflite::Interpreter* interpreter, - ANeuralNetworksModel* nn_model, uint32_t next_id) { + ANeuralNetworksModel* nn_model, uint32_t next_id, + std::vector* model_state_inputs, + std::vector* model_state_outputs) { for (size_t i = 0; i < interpreter->nodes_size(); i++) { const auto* node_and_registration = interpreter->node_and_registration(i); const TfLiteNode& node = node_and_registration->first; @@ -192,6 +198,8 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, // Add the parameters. std::vector augmented_inputs( node.inputs->data, node.inputs->data + node.inputs->size); + std::vector augmented_outputs( + node.outputs->data, node.outputs->data + node.outputs->size); auto add_scalar_int32 = [&nn_model, &augmented_inputs, &next_id](int value) { @@ -211,15 +219,29 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, augmented_inputs.push_back(next_id++); }; + // Handle state tensors of RNN, LSTM, SVDF. + // For each state_out tensor, a corresponding state_in operand needs to be + // created for NNAPI. auto duplicate_state_tensor_float32 = - [interpreter, &nn_model, &augmented_inputs](int tensor_id) { + [interpreter, &nn_model, &next_id, &augmented_inputs, + &model_state_inputs, &model_state_outputs](int tensor_id) { const TfLiteTensor* tensor = interpreter->tensor(tensor_id); - CHECK_NN(ANeuralNetworksModel_setOperandValue( - nn_model, tensor_id, tensor->data.raw, tensor->bytes)); - augmented_inputs.push_back(tensor_id); + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), + tensor->params.scale, tensor->params.zero_point}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + augmented_inputs.push_back(next_id); + model_state_inputs->push_back(next_id); + model_state_outputs->push_back(tensor_id); + next_id++; }; - auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); }; + auto add_add_params = [&add_scalar_int32](void* data) { + auto* builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; auto add_pooling_params = [&add_scalar_int32](void* data) { auto builtin = reinterpret_cast(data); @@ -279,39 +301,62 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_scalar_float32(builtin->proj_clip); }; + // LSTM in NNAPI requires scratch tensor as an output operand. + auto add_lstm_scratch_tensor_float32 = [interpreter, &node, &nn_model, + &next_id, &augmented_outputs]() { + int scratch_buffer_index = node.temporaries->data[0]; + const TfLiteTensor* tensor = interpreter->tensor(scratch_buffer_index); + ANeuralNetworksOperandType operand_type{ + ANEURALNETWORKS_TENSOR_FLOAT32, + static_cast(tensor->dims->size), + reinterpret_cast(tensor->dims->data), tensor->params.scale, + tensor->params.zero_point}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)); + augmented_outputs.insert(augmented_outputs.begin(), next_id++); + }; + auto add_mean_params = [&add_scalar_int32](void* data) { auto builtin = reinterpret_cast(data); add_scalar_int32(builtin->keep_dims); }; -#if 0 - auto add_reshape_params = [&](void* data) { - auto builtin = reinterpret_cast(data); - uint32_t tensor_size_shape = builtin->num_dimensions; - ANeuralNetworksOperandType operand_type{ - ANEURALNETWORKS_TENSOR_INT32, - {static_cast(1), - reinterpret_cast(&tensor_size_shape)}, - 0, - 0}; - CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) - CHECK_NN(ANeuralNetworksModel_setOperandValue( - nn_model, next_id, builtin->shape, - sizeof(int) * builtin->num_dimensions)); - augmented_inputs.push_back(next_id++); + auto add_svdf_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->rank); + add_scalar_int32(builtin->activation); }; -#endif + + auto add_rnn_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast(data); + add_scalar_int32(builtin->activation); + }; + + // Handle optional input tensors. + auto add_optional_tensors = [&nn_model, &augmented_inputs, + &next_id](int nn_type) { + for (size_t idx = 0; idx < augmented_inputs.size(); idx++) { + if (augmented_inputs[idx] == kOptionalTensor) { + const std::vector dim = {0, 0}; + ANeuralNetworksOperandType operand_type{nn_type, 2, dim.data(), 0, 0}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue(nn_model, next_id, + nullptr, 0)) + augmented_inputs[idx] = next_id++; + } + } + }; + int nnapi_version = 10; ANeuralNetworksOperationType nn_op_type; switch (builtin) { case tflite::BuiltinOperator_ADD: nn_op_type = ANEURALNETWORKS_ADD; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_MUL: nn_op_type = ANEURALNETWORKS_MUL; - add_add_params(); + add_add_params(node.builtin_data); break; case tflite::BuiltinOperator_AVERAGE_POOL_2D: add_pooling_params(node.builtin_data); @@ -370,13 +415,31 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, break; case tflite::BuiltinOperator_LSTM: { duplicate_state_tensor_float32( - node.outputs->data[/*kOutputStateTensor*/ 1]); + node.outputs->data[/*kOutputStateTensor*/ 0]); duplicate_state_tensor_float32( - node.outputs->data[/*kCellStateTensor*/ 2]); + node.outputs->data[/*kCellStateTensor*/ 1]); add_lstm_params(node.builtin_data); + add_lstm_scratch_tensor_float32(); + add_optional_tensors(ANEURALNETWORKS_TENSOR_FLOAT32); nn_op_type = ANEURALNETWORKS_LSTM; break; } + case tflite::BuiltinOperator_SVDF: { + duplicate_state_tensor_float32(node.outputs->data[/*kStateTensor*/ 0]); + add_svdf_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SVDF; + break; + } + case tflite::BuiltinOperator_RNN: { + duplicate_state_tensor_float32( + node.outputs->data[/*kHiddenStateTensor*/ 0]); + add_rnn_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_RNN; + break; + } + case tflite::BuiltinOperator_EMBEDDING_LOOKUP: + nn_op_type = ANEURALNETWORKS_EMBEDDING_LOOKUP; + break; case tflite::BuiltinOperator_PAD: nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_PAD; @@ -396,12 +459,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: - case tflite::BuiltinOperator_SVDF: case tflite::BuiltinOperator_HASHTABLE_LOOKUP: - case tflite::BuiltinOperator_RNN: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: - case tflite::BuiltinOperator_EMBEDDING_LOOKUP: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: @@ -437,7 +497,13 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SELECT: case tflite::BuiltinOperator_SLICE: case tflite::BuiltinOperator_SIN: + case tflite::BuiltinOperator_LOG: case tflite::BuiltinOperator_TRANSPOSE_CONV: + case tflite::BuiltinOperator_TILE: + case tflite::BuiltinOperator_EXPAND_DIMS: + case tflite::BuiltinOperator_SPARSE_TO_DENSE: + case tflite::BuiltinOperator_EQUAL: + case tflite::BuiltinOperator_NOT_EQUAL: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; @@ -454,8 +520,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, // Add the operation. CHECK_NN(ANeuralNetworksModel_addOperation( nn_model, nn_op_type, static_cast(augmented_inputs.size()), - augmented_inputs.data(), static_cast(node.outputs->size), - reinterpret_cast(node.outputs->data))); + augmented_inputs.data(), + static_cast(augmented_outputs.size()), + reinterpret_cast(augmented_outputs.data()))); } } @@ -479,12 +546,25 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) { } uint32_t next_id = addTensorOperands(interpreter, nn_model_, skip_list); - AddOpsAndParams(interpreter, nn_model_, next_id); + AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_, + &model_states_outputs_); + + std::vector augmented_inputs = interpreter->inputs(); + std::vector augmented_outputs = interpreter->outputs(); + + // All state tensors input/output need to be treated as model input/output. + augmented_inputs.insert(augmented_inputs.end(), + model_states_inputs_.begin(), + model_states_inputs_.end()); + augmented_outputs.insert(augmented_outputs.end(), + model_states_outputs_.begin(), + model_states_outputs_.end()); + CHECK_NN(ANeuralNetworksModel_identifyInputsAndOutputs( - nn_model_, static_cast(interpreter->inputs().size()), - reinterpret_cast(interpreter->inputs().data()), - static_cast(interpreter->outputs().size()), - reinterpret_cast(interpreter->outputs().data()))); + nn_model_, static_cast(augmented_inputs.size()), + reinterpret_cast(augmented_inputs.data()), + static_cast(augmented_outputs.size()), + reinterpret_cast(augmented_outputs.data()))); CHECK_NN(ANeuralNetworksModel_finish(nn_model_)); } if (!nn_compiled_model_) { @@ -511,6 +591,7 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { CHECK_NN(ANeuralNetworksExecution_setInput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } + // Tell nn api where to place final data. for (size_t i = 0; i < interpreter->outputs().size(); i++) { int output = interpreter->outputs()[i]; @@ -518,6 +599,24 @@ TfLiteStatus NNAPIDelegate::Invoke(Interpreter* interpreter) { CHECK_NN(ANeuralNetworksExecution_setOutput( execution, i, nullptr, tensor->data.raw, tensor->bytes)); } + + // The state_out of previous invocation need to be mapped to state_in of + // current invocation. + for (size_t i = 0; i < model_states_outputs_.size(); i++) { + int state_tensor_idx = model_states_outputs_[i]; + TfLiteTensor* tensor = interpreter->tensor(state_tensor_idx); + // Here we are using a deep copy for state_in tensors so that we are not + // reading and writing into the same buffer during a invocation. + // TODO(miaowang): using double shared buffer to minimize the copies. + CHECK_NN(ANeuralNetworksExecution_setInput( + execution, i + interpreter->inputs().size(), nullptr, tensor->data.raw, + tensor->bytes)); + // Tell NNAPI where to output the state_out. + CHECK_NN(ANeuralNetworksExecution_setOutput( + execution, i + interpreter->outputs().size(), nullptr, tensor->data.raw, + tensor->bytes)); + } + // Currently use blocking compute. ANeuralNetworksEvent* event = nullptr; CHECK_NN(ANeuralNetworksExecution_startCompute(execution, &event)); diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h index e98000929a1168c786f6c18f498f9d1d72311ada..94dea4f9b23f208fddbacd3c77d889ea753a8a1d 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.h +++ b/tensorflow/contrib/lite/nnapi_delegate.h @@ -59,6 +59,14 @@ class NNAPIDelegate { ANeuralNetworksModel* nn_model_ = nullptr; // The NN API compilation handle ANeuralNetworksCompilation* nn_compiled_model_ = nullptr; + + // List of state tensors for LSTM, RNN, SVDF. + // NN API does not allow ops to maintain states across multiple + // invocations. We need to manually create state input tensors from + // corresponding state output tensors of TFLite operations, and map them + // correctly. + std::vector model_states_inputs_; + std::vector model_states_outputs_; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h index 38a27069421586f28a5fbe4c7880a28f80548b98..9d7e3f20854a3596181ffa885cc17cfdbd16356e 100644 --- a/tensorflow/contrib/lite/op_resolver.h +++ b/tensorflow/contrib/lite/op_resolver.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" +#include "tensorflow/contrib/lite/util.h" namespace tflite { @@ -55,8 +56,7 @@ struct OperatorKeyHasher { size_t operator()(const T& x) const { size_t a = ValueHasher()(x.first); size_t b = ValueHasher()(x.second); - // Hash combinator used by TensorFlow core. - return a ^ (b + 0x9e3779b97f4a7800ULL + (a << 10) + (a >> 4)); + return CombineHashes({a, b}); } }; } // namespace op_resolver_hasher diff --git a/tensorflow/contrib/lite/profiling/BUILD b/tensorflow/contrib/lite/profiling/BUILD index 15999e5d4188db1e191936ae6d84faf8cce5ca6e..a162b87b8f98576ec7c3b2623d1d34f2baef6cce 100644 --- a/tensorflow/contrib/lite/profiling/BUILD +++ b/tensorflow/contrib/lite/profiling/BUILD @@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + common_copts = [ "-Wall", -] +] + tflite_copts() cc_library( name = "profiler", @@ -29,6 +31,43 @@ cc_library( name = "profile_buffer", hdrs = ["profile_buffer.h"], copts = common_copts, + deps = [":time"], +) + +cc_library( + name = "time", + srcs = ["time.cc"], + hdrs = ["time.h"], + copts = common_copts, +) + +cc_library( + name = "profile_summarizer", + srcs = ["profile_summarizer.cc"], + hdrs = ["profile_summarizer.h"], + copts = common_copts, + deps = [ + ":profiler", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:stats_calculator_portable", + ], +) + +cc_test( + name = "profile_summarizer_test", + srcs = ["profile_summarizer_test.cc"], + copts = common_copts, + deps = [ + ":profile_summarizer", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:kernel_util", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], ) cc_test( diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/contrib/lite/profiling/profile_buffer.h index 299b2a9cad161ce05ba68f39cf612f9866a0b656..65d86dce47f397c7dad6cc2beb8ffa1f95b29d45 100644 --- a/tensorflow/contrib/lite/profiling/profile_buffer.h +++ b/tensorflow/contrib/lite/profiling/profile_buffer.h @@ -18,6 +18,8 @@ limitations under the License. #include #include +#include "tensorflow/contrib/lite/profiling/time.h" + namespace tflite { namespace profiling { @@ -74,7 +76,7 @@ class ProfileBuffer { if (!enabled_) { return kInvalidEventHandle; } - uint64_t timestamp = NowMicros(); + uint64_t timestamp = time::NowMicros(); int index = current_index_ % event_buffer_.size(); event_buffer_[index].tag = tag; event_buffer_[index].event_type = event_type; @@ -103,7 +105,7 @@ class ProfileBuffer { } int event_index = event_handle % max_size; - event_buffer_[event_index].end_timestamp_us = NowMicros(); + event_buffer_[event_index].end_timestamp_us = time::NowMicros(); } // Returns the size of the buffer. @@ -134,12 +136,6 @@ class ProfileBuffer { } private: - static uint64_t NowMicros() { - // TODO(shashishekhar): Refactor this to a separate file. - struct timeval tv; - gettimeofday(&tv, nullptr); - return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; - } bool enabled_; uint32_t current_index_; std::vector event_buffer_; diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.cc b/tensorflow/contrib/lite/profiling/profile_summarizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..45388b500c7897c8b33b49eb6ab4e9f8c4fdb37c --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.cc @@ -0,0 +1,148 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" + +#include + +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace profiling { +namespace { + +using Detail = tensorflow::StatsCalculator::Detail; + +struct OperatorDetails { + std::string name; + std::vector inputs; + std::vector outputs; +}; + +std::string GetTensorName(const tflite::Interpreter& interpreter, + int tensor_index) { + const auto tensor = interpreter.tensor(tensor_index); + if (tensor == nullptr || tensor->name == nullptr) { + return "Unknown"; + } + return tensor->name; +} +std::vector GetTensorNames(const tflite::Interpreter& interpreter, + const TfLiteIntArray* tensor_indices) { + std::vector tensors; + tensors.reserve(tensor_indices->size); + for (int i = 0; i < tensor_indices->size; i++) { + tensors.push_back(GetTensorName(interpreter, tensor_indices->data[i])); + } + return tensors; +} + +std::string ToString(const std::vector& str_vector) { + std::stringstream stream; + stream << "["; + bool first = true; + for (const auto& s : str_vector) { + if (!first) { + stream << ", "; + } else { + first = false; + } + stream << s; + } + stream << "]"; + return stream.str(); +} + +OperatorDetails GetOperatorDetails(const tflite::Interpreter& interpreter, + int node_index) { + auto node_reg = interpreter.node_and_registration(node_index); + auto inputs = node_reg->first.inputs; + auto outputs = node_reg->first.outputs; + int code = node_reg->second.builtin_code; + const char* op_name = nullptr; + if (code == tflite::BuiltinOperator_CUSTOM) { + const char* custom_name = node_reg->second.custom_name; + op_name = custom_name ? custom_name : "UnknownCustomOp"; + } else { + op_name = tflite::EnumNamesBuiltinOperator()[code]; + } + OperatorDetails details; + details.name = op_name; + details.inputs = GetTensorNames(interpreter, inputs); + details.outputs = GetTensorNames(interpreter, outputs); + return details; +} + +tensorflow::StatSummarizerOptions GetProfileSummarizerOptions() { + auto options = tensorflow::StatSummarizerOptions(); + options.show_summary = true; + options.show_memory = false; + return options; +} + +} // namespace + +ProfileSummarizer::ProfileSummarizer() + : stats_calculator_( + new ::tensorflow::StatsCalculator(GetProfileSummarizerOptions())) {} + +void ProfileSummarizer::ProcessProfiles( + const std::vector& profile_stats, + const tflite::Interpreter& interpreter) { + std::vector events; + std::copy_if(profile_stats.begin(), profile_stats.end(), + std::back_inserter(events), [](const ProfileEvent* e) { + return e->event_type == + ProfileEvent::EventType::OPERATOR_INVOKE_EVENT && + e->end_timestamp_us >= e->begin_timestamp_us; + }); + // Sort with begin_time. + std::sort(events.begin(), events.end(), + [](const ProfileEvent* const& a, const ProfileEvent* const& b) { + return a->begin_timestamp_us < b->begin_timestamp_us; + }); + if (events.empty()) { + return; + } + + int64_t base_start_us = events[0]->begin_timestamp_us; + int node_num = 0; + int64_t curr_total_us = 0; + std::map details; + for (auto event : events) { + auto op_details = GetOperatorDetails(interpreter, event->event_metadata); + auto node_name = ToString(op_details.outputs); + auto result = details.emplace(node_name, Detail()); + Detail* detail = &(result.first->second); + detail->start_us.UpdateStat(event->begin_timestamp_us - base_start_us); + int64_t node_exec_time = + event->end_timestamp_us - event->begin_timestamp_us; + detail->rel_end_us.UpdateStat(node_exec_time); + curr_total_us += node_exec_time; + ++node_num; + + if (result.second) { + detail->name = node_name; + detail->type = op_details.name; + detail->run_order = node_num; + detail->times_called = 0; + } + ++detail->times_called; + } + stats_calculator_->UpdateDetails(details); + stats_calculator_->UpdateRunTotalUs(curr_total_us); +} +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer.h b/tensorflow/contrib/lite/profiling/profile_summarizer.h new file mode 100644 index 0000000000000000000000000000000000000000..a529ff87428d70d002241311d7f70f185521020f --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ +#define TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ + +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/profiling/profiler.h" +#include "tensorflow/core/util/stats_calculator.h" + +namespace tflite { +namespace profiling { + +// Creates a summary of operator invocations in the interpreter. +class ProfileSummarizer { + public: + ProfileSummarizer(); + virtual ~ProfileSummarizer() {} + + // Process profile events to update statistics for operator invocations. + void ProcessProfiles(const std::vector& profile_stats, + const tflite::Interpreter& interpreter); + + // Returns a string detailing the accumulated runtime stats in a tab-separated + // format which can be pasted into a spreadsheet for further analysis. + std::string GetOutputString() const { + return stats_calculator_->GetOutputString(); + } + + std::string GetShortSummary() const { + return stats_calculator_->GetShortSummary(); + } + + private: + std::unique_ptr stats_calculator_; +}; + +} // namespace profiling +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_PROFILE_SUMMARIZER_H_ diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..35cf780713b93db559f86dcaf62e1ac004b5049a --- /dev/null +++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { +namespace profiling { + +namespace { + +TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0); + const TfLiteTensor* input2 = tflite::GetInput(context, node, /*index=*/1); + + TfLiteTensor* output = GetOutput(context, node, /*index=*/0); + + int32_t* output_data = output->data.i32; + *output_data = *(input1->data.i32) + *(input2->data.i32); + return kTfLiteOk; +} + +TfLiteRegistration* RegisterSimpleOp() { + static TfLiteRegistration registration = {nullptr, + nullptr, + nullptr, + SimpleOpEval, + tflite::BuiltinOperator_CUSTOM, + "SimpleOpEval", + 1}; + return ®istration; +} + +class SimpleOpModel : public SingleOpModel { + public: + void Init(); + tflite::Interpreter* GetInterpreter() { return interpreter_.get(); } + void SetInputs(int32_t x, int32_t y) { + PopulateTensor(inputs_[0], {x}); + PopulateTensor(inputs_[1], {y}); + } + int32_t GetOutput() { return ExtractVector(output_)[0]; } + + private: + int inputs_[2]; + int output_; +}; + +void SimpleOpModel::Init() { + inputs_[0] = AddInput({TensorType_INT32, {1}}); + inputs_[1] = AddInput({TensorType_INT32, {1}}); + output_ = AddOutput({TensorType_INT32, {}}); + SetCustomOp("SimpleAdd", {}, RegisterSimpleOp); + BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])}); +} + +TEST(ProfileSummarizerTest, Empty) { + ProfileSummarizer summarizer; + std::string output = summarizer.GetOutputString(); + EXPECT_GT(output.size(), 0); +} + +#ifdef TFLITE_PROFILING_ENABLED +TEST(ProfileSummarizerTest, Interpreter) { + Profiler profiler; + SimpleOpModel m; + m.Init(); + auto interpreter = m.GetInterpreter(); + interpreter->SetProfiler(&profiler); + profiler.StartProfiling(); + m.SetInputs(1, 2); + m.Invoke(); + // 3 = 1 + 2 + EXPECT_EQ(m.GetOutput(), 3); + profiler.StopProfiling(); + ProfileSummarizer summarizer; + auto events = profiler.GetProfileEvents(); + EXPECT_EQ(1, events.size()); + summarizer.ProcessProfiles(profiler.GetProfileEvents(), *interpreter); + auto output = summarizer.GetOutputString(); + // TODO(shashishekhar): Add a better test here. + ASSERT_TRUE(output.find("SimpleOp") != std::string::npos) << output; +} +#endif + +} // namespace +} // namespace profiling +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/profiling/time.cc b/tensorflow/contrib/lite/profiling/time.cc new file mode 100644 index 0000000000000000000000000000000000000000..446660bb747cd6e3b694669b64ac1d95cf415fbe --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/profiling/time.h" + +#include + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return static_cast(tv.tv_sec) * 1000000 + tv.tv_usec; +} +} // namespace time +} // namespace profiling +} // namespace tflite diff --git a/tensorflow/contrib/lite/profiling/time.h b/tensorflow/contrib/lite/profiling/time.h new file mode 100644 index 0000000000000000000000000000000000000000..cc2ec319b8a95b3efa0aab0ac9f97a88bf7b5536 --- /dev/null +++ b/tensorflow/contrib/lite/profiling/time.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ +#define TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ + +#include + +namespace tflite { +namespace profiling { +namespace time { +uint64_t NowMicros(); +} // namespace time +} // namespace profiling +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_PROFILING_TIME_H_ diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 4920e83970d1cb7f60a38f95ea05986d52b0bbe7..27909a9458f6b09f96cb556a5254f01e54f46e05 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -36,6 +36,16 @@ py_test( ], ) +py_binary( + name = "tflite_convert", + srcs = ["tflite_convert.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":lite", + ], +) + py_library( name = "lite", srcs = ["lite.py"], @@ -45,7 +55,23 @@ py_library( ":convert", ":convert_saved_model", ":interpreter", + ":lite_constants", ":op_hint", + "//tensorflow/python:graph_util", + "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/tools:freeze_graph_lib", + ], +) + +py_test( + name = "lite_test", + srcs = ["lite_test.py"], + data = [":interpreter_test_data"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":lite", ], ) @@ -111,9 +137,9 @@ py_library( visibility = ["//visibility:public"], deps = [ ":convert", - ":lite_constants", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/python:graph_util", + "//tensorflow/python:platform", "//tensorflow/python/tools:freeze_graph_lib", ], ) @@ -150,20 +176,3 @@ py_test( "//tensorflow/python/saved_model", ], ) - -py_binary( - name = "convert_saved_model_to_frozen_graph", - srcs = ["convert_saved_model_to_frozen_graph.py"], - srcs_version = "PY2AND3", - deps = [ - ":convert_saved_model", - ], -) - -# Transitive dependencies of this target will be included in the pip package. -py_library( - name = "tf_lite_py_pip", - deps = [ - ":convert_saved_model", - ], -) diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c0926d2f33c0bbc5111e6df90dbd759172021f95..c038c88945b71f30bf091a1098dcf853f5415b1b 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -111,38 +111,75 @@ def tensor_name(x): return x.name.split(":")[0] -def toco_convert(input_data, - input_tensors, - output_tensors, - inference_type=lite_constants.FLOAT, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - drop_control_dependency=True, - allow_custom_ops=False): - """Convert a model using TOCO from `input_format` to `output_format`. +def build_toco_convert_protos(input_tensors, + output_tensors, + inference_type=lite_constants.FLOAT, + inference_input_type=None, + input_format=lite_constants.TENSORFLOW_GRAPHDEF, + output_format=lite_constants.TFLITE, + quantized_input_stats=None, + default_ranges_stats=None, + drop_control_dependency=True, + reorder_across_fake_quant=False, + allow_custom_ops=False, + change_concat_input_ranges=False, + quantize_weights=False, + dump_graphviz_dir=None, + dump_graphviz_video=False): + """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which case the default `input_format` and `output_format` are sufficient. Args: - input_data: Input data (i.e. often `sess.graph_def`). input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + input_format: Type of data to read Currently must be + `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: List of tuples of integers representing the mean and + standard deviation. Each tuple maps to the corresponding input tensor. + Only need if `inference_type` is `QUANTIZED_UINT8`. (default None) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. + model_flags, toco_flags: two protocol buffers describing the conversion + process. Raises: ValueError: If the input tensor type is unknown @@ -152,10 +189,21 @@ def toco_convert(input_data, toco = _toco_flags_pb2.TocoFlags() toco.input_format = input_format toco.output_format = output_format - toco.drop_control_dependency = drop_control_dependency - model = _model_flags_pb2.ModelFlags() toco.inference_type = inference_type + if inference_input_type: + toco.inference_input_type = inference_input_type + toco.drop_control_dependency = drop_control_dependency + toco.reorder_across_fake_quant = reorder_across_fake_quant toco.allow_custom_ops = allow_custom_ops + toco.quantize_weights = quantize_weights + if default_ranges_stats: + toco.default_ranges_min = default_ranges_stats[0] + toco.default_ranges_max = default_ranges_stats[1] + if dump_graphviz_dir: + toco.dump_graphviz_dir = dump_graphviz_dir + toco.dump_graphviz_include_video = dump_graphviz_video + model = _model_flags_pb2.ModelFlags() + model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): if input_tensor.dtype == _dtypes.float32: tflite_input_type = lite_constants.FLOAT @@ -163,6 +211,8 @@ def toco_convert(input_data, tflite_input_type = lite_constants.INT32 elif input_tensor.dtype == _dtypes.int64: tflite_input_type = lite_constants.INT64 + elif input_tensor.dtype == _dtypes.uint8: + tflite_input_type = lite_constants.QUANTIZED_UINT8 # TODO(aselle): Insert strings when they are available else: raise ValueError("Tensors %s not known type %r" % (input_tensor.name, @@ -180,10 +230,35 @@ def toco_convert(input_data, for output_tensor in output_tensors: model.output_arrays.append(tensor_name(output_tensor)) + return model, toco + + +def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): + """"Convert a model using TOCO. + + Typically this function is used to convert from TensorFlow GraphDef to TFLite. + Conversion can be customized by providing arguments that are forwarded to + `build_toco_convert_protos` (see documentation for details). + + Args: + input_data: Input data (i.e. often `sess.graph_def`), + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + *args: See `build_toco_convert_protos`, + **kwargs: See `build_toco_convert_protos`. - # TODO(aselle): Consider handling the case of allowing quantized - # inputs to be converted to float (via the toco.inference_input_type field). - data = toco_convert_protos(model.SerializeToString(), - toco.SerializeToString(), + Returns: + The converted data. For example if TFLite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + Defined in `build_toco_convert_protos`. + """ + model_flags, toco_flags = build_toco_convert_protos(input_tensors, + output_tensors, + *args, **kwargs) + data = toco_convert_protos(model_flags.SerializeToString(), + toco_flags.SerializeToString(), input_data.SerializeToString()) return data diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py index a7eddf3408f54dff5fa49ff6fa7b61cd0b8a22e4..1553464b9fe30f596c151bcc67efe891bb913ba3 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model.py +++ b/tensorflow/contrib/lite/python/convert_saved_model.py @@ -18,34 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.lite.python import convert -from tensorflow.contrib.lite.python import lite_constants -from tensorflow.contrib.lite.toco import model_flags_pb2 -from tensorflow.contrib.saved_model.python.saved_model import reader -from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.core.framework import types_pb2 from tensorflow.python.client import session from tensorflow.python.framework import graph_util as tf_graph_util from tensorflow.python.framework import ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader -from tensorflow.python.saved_model import signature_constants -from tensorflow.python.saved_model import tag_constants - - -def _write_and_flush_file(file_path, data_str): - """Writes data to file path. - - Args: - file_path: Full path of the file to store data in. - data_str: Data represented as a string. - - Returns: None. - """ - with gfile.Open(file_path, "wb") as data_file: - data_file.write(data_str) - data_file.flush() def _log_tensor_details(tensor_info): @@ -77,21 +57,8 @@ def _get_meta_graph_def(saved_model_dir, tag_set): Raises: ValueError: No valid MetaGraphDef for given tag_set. """ - saved_model = reader.read_saved_model(saved_model_dir) - tag_sets = [] - result_meta_graph_def = None - for meta_graph_def in saved_model.meta_graphs: - meta_graph_tag_set = set(meta_graph_def.meta_info_def.tags) - tag_sets.append(meta_graph_tag_set) - if meta_graph_tag_set == tag_set: - result_meta_graph_def = meta_graph_def - logging.info("The given saved_model contains the following tags: %s", - tag_sets) - if result_meta_graph_def is not None: - return result_meta_graph_def - else: - raise ValueError("No valid MetaGraphDef for this tag_set '{}'. Possible " - "values are '{}'. ".format(tag_set, tag_sets)) + with session.Session(graph=ops.Graph()) as sess: + return loader.load(sess, tag_set, saved_model_dir) def _get_signature_def(meta_graph, signature_key): @@ -110,15 +77,13 @@ def _get_signature_def(meta_graph, signature_key): signature_def_map = meta_graph.signature_def signature_def_keys = set(signature_def_map.keys()) logging.info( - "The given saved_model MetaGraphDef contains SignatureDefs with the " + "The given SavedModel MetaGraphDef contains SignatureDefs with the " "following keys: %s", signature_def_keys) if signature_key not in signature_def_keys: - raise ValueError("No '{}' in the saved_model\'s SignatureDefs. Possible " - "values are '{}'. ".format(signature_key, - signature_def_keys)) - signature_def = signature_def_utils.get_signature_def_by_key( - meta_graph, signature_key) - return signature_def + raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible " + "values are '{}'.".format(signature_key, + ",".join(signature_def_keys))) + return signature_def_map[signature_key] def _get_inputs_outputs(signature_def): @@ -170,29 +135,10 @@ def _get_tensors(graph, signature_def_tensor_names=None, """ tensors = [] if user_tensor_names: - # Get the list of all of the tensors with and without the tensor index. - all_tensor_names = [ - tensor.name for op in graph.get_operations() for tensor in op.outputs - ] - all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names] - # Sort the tensor names. user_tensor_names = sorted(user_tensor_names) - # Get the tensors associated with the tensor names. - tensors = [] - invalid_tensors = [] - for name in user_tensor_names: - if name not in all_tensor_names_only: - invalid_tensors.append(name) - else: - idx = all_tensor_names_only.index(name) - tensors.append(graph.get_tensor_by_name(all_tensor_names[idx])) - - # Throw ValueError if any user input names are not valid tensors. - if invalid_tensors: - raise ValueError("Invalid tensors '{}' were found.".format( - ",".join(invalid_tensors))) + tensors = get_tensors_from_tensor_names(graph, user_tensor_names) elif signature_def_tensor_names: tensors = [ graph.get_tensor_by_name(name) @@ -207,25 +153,74 @@ def _get_tensors(graph, signature_def_tensor_names=None, return tensors -def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, - output_arrays, tag_set, signature_key, batch_size): +def get_tensors_from_tensor_names(graph, tensor_names): + """Gets the Tensors associated with the `tensor_names` in the provided graph. + + Args: + graph: TensorFlow Graph. + tensor_names: List of strings that represent names of tensors in the graph. + + Returns: + A list of Tensor objects in the same order the names are provided. + + Raises: + ValueError: + tensor_names contains an invalid tensor name. + """ + # Get the list of all of the tensors. + tensor_name_to_tensor = { + tensor_name(tensor): tensor for op in graph.get_operations() + for tensor in op.values() + } + + # Get the tensors associated with tensor_names. + tensors = [] + invalid_tensors = [] + for name in tensor_names: + tensor = tensor_name_to_tensor.get(name) + if tensor is None: + invalid_tensors.append(name) + else: + tensors.append(tensor) + + # Throw ValueError if any user input names are not valid tensors. + if invalid_tensors: + raise ValueError("Invalid tensors '{}' were found.".format( + ",".join(invalid_tensors))) + return tensors + + +def set_tensor_shapes(tensors, shapes): + """Sets Tensor shape for each tensor if the shape is defined. + + Args: + tensors: TensorFlow ops.Tensor. + shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). + """ + if shapes: + for tensor in tensors: + shape = shapes.get(tensor_name(tensor)) + if shape is not None: + tensor.set_shape(shape) + + +def freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key): """Converts a SavedModel to a frozen graph. Args: saved_model_dir: SavedModel directory to convert. input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of + from SignatureDef when none are provided. + input_shapes: Dict of strings representing input tensor names to list of integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) + arrays from SignatureDef when none are provided. tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") + analyze. All tags in the tag set must be present. signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) Returns: frozen_graph_def: Frozen GraphDef. @@ -236,186 +231,32 @@ def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes, ValueError: SavedModel doesn't contain a MetaGraphDef identified by tag_set. signature_key is not in the MetaGraphDef. + assets/ directory is in the MetaGraphDef. input_shapes does not match the length of input_arrays. - input_shapes has a None value after the 1st dimension. input_arrays or output_arrays are not valid. - Unable to load Session. """ - # Set default values for inputs if they are set to None. - if signature_key is None: - signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - if tag_set is None: - tag_set = set([tag_constants.SERVING]) - if batch_size is None: - batch_size = 1 - # Read SignatureDef. meta_graph = _get_meta_graph_def(saved_model_dir, tag_set) signature_def = _get_signature_def(meta_graph, signature_key) inputs, outputs = _get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. + collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + graph = ops.Graph() with session.Session(graph=graph) as sess: - # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory. loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir) # Gets input and output tensors. # TODO(zhixianyan): Use TFLite supported Op list to filter outputs. in_tensors = _get_tensors(graph, inputs, input_arrays) out_tensors = _get_tensors(graph, outputs, output_arrays) - - # Gets fully defined tensor shape. An input tensor with None in the first - # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size. - # Shapes with None after the first dimension result in a ValueError. - # TODO(zhixianyan): Add supports for input tensor with more None in shape. - for tensor in in_tensors: - if (input_shapes and tensor.name in input_shapes and - input_shapes[tensor.name] is not None): - shape = input_shapes[tensor.name] - else: - shape = tensor.get_shape().as_list() - - if None in shape[1:]: - raise ValueError( - "None is only supported in the 1st dimension. Tensor '{0}' has " - "invalid shape '{1}'.".format(tensor.name, shape)) - elif shape[0] is None: - shape[0] = batch_size - tensor.set_shape(shape) + set_tensor_shapes(in_tensors, input_shapes) output_names = [node.split(":")[0] for node in outputs] frozen_graph_def = tf_graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), output_names) return frozen_graph_def, in_tensors, out_tensors - raise ValueError("Unable to load Session.") - - -def saved_model_to_frozen_graphdef( - saved_model_dir, - output_file_model, - output_file_flags, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1): - """Converts a SavedModel to a frozen graph. Writes graph to tmp directory. - - Stores frozen graph and command line flags in the tmp directory. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file_model: Full file path to save frozen graph. - output_file_flags: Full file path to save ModelFlags. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - - Returns: None. - - Raises: - ValueError: Unable to convert to frozen graph. - """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - # Initialize model flags. - model = model_flags_pb2.ModelFlags() - - for input_tensor in in_tensors: - input_array = model.input_arrays.add() - input_array.name = convert.tensor_name(input_tensor) - input_array.shape.dims.extend(map(int, input_tensor.get_shape())) - - for output_tensor in out_tensors: - model.output_arrays.append(convert.tensor_name(output_tensor)) - - # Write model and ModelFlags to file. ModelFlags contain input array and - # output array information that is parsed from the SignatureDef and used for - # analysis by TOCO. - _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString()) - _write_and_flush_file(output_file_flags, model.SerializeToString()) - - -def tflite_from_saved_model( - saved_model_dir, - output_file=None, - input_arrays=None, - input_shapes=None, - output_arrays=None, - tag_set=None, - signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - batch_size=1, - inference_type=lite_constants.FLOAT, - input_format=lite_constants.TENSORFLOW_GRAPHDEF, - output_format=lite_constants.TFLITE, - quantized_input_stats=None, - drop_control_dependency=True): - """Converts a SavedModel to TFLite FlatBuffer. - - Args: - saved_model_dir: SavedModel directory to convert. - output_file: File path to write result TFLite FlatBuffer. - input_arrays: List of input tensors to freeze graph with. Uses input arrays - from SignatureDef when none are provided. (default None) - input_shapes: Map of strings representing input tensor names to list of - integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}). - Automatically determined when input shapes is None (e.g., {"foo" : None}). - (default None) - output_arrays: List of output tensors to freeze graph with. Uses output - arrays from SignatureDef when none are provided. (default None) - tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to - analyze. All tags in the tag set must be present. (default "serve") - signature_key: Key identifying SignatureDef containing inputs and outputs. - batch_size: Batch size for the model. Replaces the first dimension of an - input size array if undefined. (default 1) - inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`. - input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF). - output_format: Type of data to write (currently must be TFLITE or - GRAPHVIZ_DOT) - quantized_input_stats: For each member of input_tensors the mean and - std deviation of training data. Only needed if `inference_type` is - `QUANTIZED_UINT8`. - drop_control_dependency: Drops control dependencies silently. This is due - to tf lite not supporting control dependencies. - - Returns: - The converted data. For example if tflite was the destination, then - this will be a tflite flatbuffer in a bytes array. - - Raises: - ValueError: Unable to convert to frozen graph. - """ - frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model( - saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set, - signature_key, batch_size) - - result = convert.toco_convert( - input_data=frozen_graph_def, - input_tensors=in_tensors, - output_tensors=out_tensors, - inference_type=inference_type, - input_format=input_format, - output_format=output_format, - quantized_input_stats=quantized_input_stats, - drop_control_dependency=drop_control_dependency) - - if output_file is not None: - with gfile.Open(output_file, "wb") as f: - f.write(result) - logging.info("Successfully converted to: %s", output_file) - - return result diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py index db95fc8ad7f94b52d33c72f6ec5819bdfe8cf05f..92c4ebb2465c2abaa1cefd020e69b2f7ad6a54a5 100644 --- a/tensorflow/contrib/lite/python/convert_saved_model_test.py +++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py @@ -25,12 +25,12 @@ from __future__ import print_function import os from tensorflow.contrib.lite.python import convert_saved_model -from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.python import keras from tensorflow.python.client import session from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.layers import layers from tensorflow.python.ops import array_ops @@ -38,13 +38,68 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops from tensorflow.python.ops.losses import losses -from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import saved_model +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import training as train -class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): +class TensorFunctionsTest(test_util.TensorFlowTestCase): + + def testGetTensorsValid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + tensors = convert_saved_model.get_tensors_from_tensor_names( + sess.graph, ["Placeholder"]) + self.assertEqual("Placeholder:0", tensors[0].name) + + def testGetTensorsInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + with self.assertRaises(ValueError) as error: + convert_saved_model.get_tensors_from_tensor_names(sess.graph, + ["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + def testSetTensorShapeValid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [5, 3, 5]}) + self.assertEqual([5, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeNoneValid(self): + tensor = array_ops.placeholder(dtype=dtypes.float32) + self.assertEqual(None, tensor.shape) + + convert_saved_model.set_tensor_shapes([tensor], {"Placeholder": [1, 3, 5]}) + self.assertEqual([1, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeInvalid(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], + {"invalid-input": [5, 3, 5]}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + def testSetTensorShapeEmpty(self): + tensor = array_ops.placeholder(shape=[None, 3, 5], dtype=dtypes.float32) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + convert_saved_model.set_tensor_shapes([tensor], {}) + self.assertEqual([None, 3, 5], tensor.shape.as_list()) + + +class FreezeSavedModelTest(test_util.TensorFlowTestCase): def _createSimpleSavedModel(self, shape): """Create a simple SavedModel on the fly.""" @@ -57,82 +112,167 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): saved_model.simple_save(sess, saved_model_dir, inputs, outputs) return saved_model_dir + def _createSavedModelTwoInputArrays(self, shape): + """Create a simple SavedModel.""" + saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") + with session.Session() as sess: + in_tensor_1 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputB") + in_tensor_2 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name="inputA") + out_tensor = in_tensor_1 + in_tensor_2 + inputs = {"x": in_tensor_1, "y": in_tensor_2} + outputs = {"z": out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def _getArrayNames(self, tensors): + return [tensor.name for tensor in tensors] + + def _getArrayShapes(self, tensors): + dims = [] + for tensor in tensors: + dim_tensor = [] + for dim in tensor.shape: + if isinstance(dim, tensor_shape.Dimension): + dim_tensor.append(dim.value) + else: + dim_tensor.append(dim) + dims.append(dim_tensor) + return dims + + def _convertSavedModel(self, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + graph_def, in_tensors, out_tensors = convert_saved_model.freeze_saved_model( + saved_model_dir=saved_model_dir, + input_arrays=input_arrays, + input_shapes=input_shapes, + output_arrays=output_arrays, + tag_set=tag_set, + signature_key=signature_key) + return graph_def, in_tensors, out_tensors + def testSimpleSavedModel(self): - """Test a simple SavedModel created on the fly.""" - # Create a simple SavedModel + """Test a SavedModel.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite - result = convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) - self.assertTrue(result) + _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) def testSimpleSavedModelWithNoneBatchSizeInShape(self): - """Test a simple SavedModel, with None in input tensor's shape.""" + """Test a SavedModel with None in input tensor's shape.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) - self.assertTrue(result) + _, in_tensors, out_tensors = self._convertSavedModel(saved_model_dir) - def testSimpleSavedModelWithMoreNoneInShape(self): - """Test a simple SavedModel, fail as more None in input shape.""" - saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3]) - # Convert to tflite: this should raise ValueError, as 3rd dim is None. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir) + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) - def testSimpleSavedModelWithWrongSignatureKey(self): - """Test a simple SavedModel, fail as given signature is invalid.""" + def testSimpleSavedModelWithInvalidSignatureKey(self): + """Test a SavedModel that fails due to an invalid signature_key.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite: this should raise ValueError, as - # signature_key does not exit in the saved_model. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, signature_key="wrong-key") - - def testSimpleSavedModelWithWrongOutputArray(self): - """Test a simple SavedModel, fail as given output_arrays is invalid.""" - # Create a simple SavedModel + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, signature_key="invalid-key") + self.assertEqual( + "No 'invalid-key' in the SavedModel's SignatureDefs. " + "Possible values are 'serving_default'.", str(error.exception)) + + def testSimpleSavedModelWithInvalidOutputArray(self): + """Test a SavedModel that fails due to invalid output arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Convert to tflite: this should raise ValueError, as - # output_arrays is not valid for the saved_model. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, output_arrays=["wrong-output"]) + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, output_arrays=["invalid-output"]) + self.assertEqual("Invalid tensors 'invalid-output' were found.", + str(error.exception)) def testSimpleSavedModelWithWrongInputArrays(self): - """Test a simple SavedModel, fail as given input_arrays is invalid.""" + """Test a SavedModel that fails due to invalid input arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - # Checks invalid input_arrays. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, input_arrays=["wrong-input"]) - # Checks valid and invalid input_arrays. - with self.assertRaises(ValueError): - convert_saved_model.tflite_from_saved_model( - saved_model_dir=saved_model_dir, - input_arrays=["Placeholder", "wrong-input"]) + + # Check invalid input_arrays. + with self.assertRaises(ValueError) as error: + self._convertSavedModel(saved_model_dir, input_arrays=["invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) + + # Check valid and invalid input_arrays. + with self.assertRaises(ValueError) as error: + self._convertSavedModel( + saved_model_dir, input_arrays=["Placeholder", "invalid-input"]) + self.assertEqual("Invalid tensors 'invalid-input' were found.", + str(error.exception)) def testSimpleSavedModelWithCorrectArrays(self): - """Test a simple SavedModel, with correct input_arrays and output_arrays.""" + """Test a SavedModel with correct input_arrays and output_arrays.""" saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, input_arrays=["Placeholder"], output_arrays=["add"]) - self.assertTrue(result) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[None, 16, 16, 3]]) def testSimpleSavedModelWithCorrectInputArrays(self): - """Test a simple SavedModel, with correct input_arrays and input_shapes.""" + """Test a SavedModel with correct input_arrays and input_shapes.""" saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - result = convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, input_arrays=["Placeholder"], input_shapes={"Placeholder": [1, 16, 16, 3]}) - self.assertTrue(result) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) + + def testTwoInputArrays(self): + """Test a simple SavedModel.""" + saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) + + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, input_arrays=["inputB", "inputA"]) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0", "inputB:0"]) + self.assertEqual( + self._getArrayShapes(in_tensors), [[1, 16, 16, 3], [1, 16, 16, 3]]) + + def testSubsetInputArrays(self): + """Test a SavedModel with a subset of the input array names of the model.""" + saved_model_dir = self._createSavedModelTwoInputArrays(shape=[1, 16, 16, 3]) + + # Check case where input shape is given. + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, + input_arrays=["inputA"], + input_shapes={"inputA": [1, 16, 16, 3]}) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) + + # Check case where input shape is None. + _, in_tensors, out_tensors = self._convertSavedModel( + saved_model_dir=saved_model_dir, input_arrays=["inputA"]) + + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["inputA:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 16, 16, 3]]) def testMultipleMetaGraphDef(self): - """Test saved model with multiple MetaGraphDef.""" + """Test saved model with multiple MetaGraphDefs.""" saved_model_dir = os.path.join(self.get_temp_dir(), "savedmodel_two_mgd") builder = saved_model.builder.SavedModelBuilder(saved_model_dir) with session.Session(graph=ops.Graph()) as sess: @@ -161,91 +301,13 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase): builder.save(True) # Convert to tflite - convert_saved_model.tflite_from_saved_model( + _, in_tensors, out_tensors = self._convertSavedModel( saved_model_dir=saved_model_dir, tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"])) - -class ConvertSavedModelTestBasicGraphToText(test_util.TensorFlowTestCase): - - def _createSimpleSavedModel(self, shape): - """Create a simple SavedModel.""" - saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel") - with session.Session() as sess: - in_tensor_1 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name="inputB") - in_tensor_2 = array_ops.placeholder( - shape=shape, dtype=dtypes.float32, name="inputA") - out_tensor = in_tensor_1 + in_tensor_2 - inputs = {"x": in_tensor_1, "y": in_tensor_2} - outputs = {"z": out_tensor} - saved_model.simple_save(sess, saved_model_dir, inputs, outputs) - return saved_model_dir - - def _getInputArrayNames(self, model_proto): - return [data.name for data in model_proto.input_arrays] - - def _getInputArrayShapes(self, model_proto): - return [ - [dim for dim in data.shape.dims] for data in model_proto.input_arrays - ] - - def _get_model_flags_proto_from_file(self, filename): - proto = _model_flags_pb2.ModelFlags() - with gfile.Open(filename, "rb") as output_file: - proto.ParseFromString(output_file.read()) - output_file.close() - return proto - - def testSimpleSavedModel(self): - """Test a simple SavedModel.""" - saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - output_file_model = os.path.join(self.get_temp_dir(), "model.pb") - output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt") - - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputB", "inputA"]) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA", "inputB"]) - self.assertEqual( - self._getInputArrayShapes(proto), [[1, 16, 16, 3], [1, 16, 16, 3]]) - - def testSimpleSavedModelWithDifferentInputNames(self): - """Test a simple SavedModel.""" - saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3]) - output_file_model = os.path.join(self.get_temp_dir(), "model.pb") - output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt") - - # Check case where input shape is given. - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputA"], - input_shapes={"inputA": [1, 16, 16, 3]}) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) - self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) - - # Check case where input shape is None. - convert_saved_model.saved_model_to_frozen_graphdef( - saved_model_dir=saved_model_dir, - output_file_model=output_file_model, - output_file_flags=output_file_flags, - input_arrays=["inputA"], - input_shapes={"inputA": None}) - - proto = self._get_model_flags_proto_from_file(output_file_flags) - self.assertEqual(proto.output_arrays, ["add"]) - self.assertEqual(self._getInputArrayNames(proto), ["inputA"]) - self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]]) + self.assertEqual(self._getArrayNames(out_tensors), ["add:0"]) + self.assertEqual(self._getArrayNames(in_tensors), ["Placeholder:0"]) + self.assertEqual(self._getArrayShapes(in_tensors), [[1, 28, 28]]) class Model(keras.Model): @@ -354,7 +416,7 @@ def dummy_input_fn(): return image, labels -class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): +class FreezeSavedModelTestTrainGraph(test_util.TensorFlowTestCase): def testTrainedMnistSavedModel(self): """Test mnist SavedModel, trained with dummy data and small steps.""" @@ -379,13 +441,16 @@ class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase): # Convert to tflite and test output saved_model_name = os.listdir(saved_model_dir)[0] saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name) - output_file = os.path.join(saved_model_dir, saved_model_final_dir + ".lite") + # TODO(zhixianyan): no need to limit output_arrays to `Softmax' # once b/74205001 fixed and argmax implemented in tflite. - result = convert_saved_model.tflite_from_saved_model( + result = convert_saved_model.freeze_saved_model( saved_model_dir=saved_model_final_dir, + input_arrays=None, + input_shapes=None, output_arrays=["Softmax"], - output_file=output_file) + tag_set=set([tag_constants.SERVING]), + signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) self.assertTrue(result) diff --git a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py deleted file mode 100644 index 4d9782f4a6a9e853c3afdbd97d4264a818937e63..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python console command for generating frozen models from SavedModels. - -This exists to add SavedModel compatibility to TOCO. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import sys -from tensorflow.contrib.lite.python.convert_saved_model import saved_model_to_frozen_graphdef -from tensorflow.python.platform import app - -FLAGS = None - - -def execute(unused_args): - """Calls function to convert the SavedModel to a frozen graph.""" - # Error handling. - if FLAGS.input_shapes and not FLAGS.input_arrays: - raise ValueError("Input shapes requires input arrays to be specified.") - - # Calls saved_model_to_frozen_graphdef function to generate frozen graph. - input_arrays = (FLAGS.input_arrays.split(",") if FLAGS.input_arrays else None) - input_shapes = None - if FLAGS.input_shapes: - input_shapes = { - input_arrays[idx]: shape.split(",") - for idx, shape in enumerate(FLAGS.input_shapes.split(":")) - } - output_arrays = ( - FLAGS.output_arrays.split(",") if FLAGS.output_arrays else None) - tag_set = set(FLAGS.tag_set.split(",")) if FLAGS.tag_set else None - - saved_model_to_frozen_graphdef( - saved_model_dir=FLAGS.saved_model_directory, - output_file_model=FLAGS.output_file_model, - output_file_flags=FLAGS.output_file_flags, - input_arrays=input_arrays, - input_shapes=input_shapes, - output_arrays=output_arrays, - tag_set=tag_set, - signature_key=FLAGS.signature_key, - batch_size=FLAGS.batch_size) - - -def main(): - global FLAGS - # Parses flags. - parser = argparse.ArgumentParser( - description="Invoke SavedModel to frozen model converter.") - parser.add_argument( - "saved_model_directory", - type=str, - help="Full path to directory containing the SavedModel.") - parser.add_argument( - "output_file_model", - type=str, - help="Full file path to save frozen graph.") - parser.add_argument( - "output_file_flags", type=str, help="Full file path to save ModelFlags.") - parser.add_argument( - "--input_arrays", - type=str, - help="Name of the input arrays, comma-separated.") - parser.add_argument( - "--input_shapes", - type=str, - help="Shapes corresponding to --input_arrays, colon-separated.") - parser.add_argument( - "--output_arrays", - type=str, - help="Name of the output arrays, comma-separated.") - parser.add_argument( - "--tag_set", type=str, help="Name of output arrays, comma-separated.") - parser.add_argument( - "--signature_key", - type=str, - help="Key identifying SignatureDef containing inputs and outputs.") - parser.add_argument( - "--batch_size", - type=int, - help="Batch size for the model. Replaces the first dimension of an " - "input size array if undefined.") - - FLAGS, unparsed = parser.parse_known_args() - - app.run(main=execute, argv=[sys.argv[0]] + unparsed) - - -if __name__ == "__main__": - main() diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py index 5fbc55145217dd8a4e3eec4108e18d1c2be5c883..779bda4c9d05fd056d6a262412fdcf0d47e7c57c 100644 --- a/tensorflow/contrib/lite/python/interpreter.py +++ b/tensorflow/contrib/lite/python/interpreter.py @@ -54,7 +54,7 @@ class Interpreter(object): elif model_content and not model_path: self._interpreter = ( _interpreter_wrapper.InterpreterWrapper_CreateWrapperCPPFromBuffer( - model_content, len(model_content))) + model_content)) if not self._interpreter: raise ValueError( 'Failed to create model from {} bytes'.format(len(model_content))) diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD index 453eda6e7345762666917fd501b69c7181c349e8..12ab38847dc0f838ae2c6bf80ed80805285e4b8b 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/BUILD @@ -15,7 +15,7 @@ cc_library( "//tensorflow/contrib/lite/kernels:builtin_ops", "//tensorflow/core:lib", "//tensorflow/python:numpy_lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", "@com_google_absl//absl/memory", ], ) @@ -27,6 +27,6 @@ tf_py_wrap_cc( ], deps = [ ":interpreter_wrapper_lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 16f4f30b94313453c3a4f0496a6ac5649847f076..5f304ad45d400b13e20bda8184b5b40cfe13f6c2 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -42,6 +42,8 @@ std::unique_ptr CreateInterpreter( return nullptr; } + tensorflow::ImportNumpy(); + std::unique_ptr interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); if (interpreter) { @@ -331,9 +333,14 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile( } InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( - const char* data, size_t len) { + PyObject* data) { + char * buf = nullptr; + Py_ssize_t length; + if (PY_TO_CPPSTRING(data, &buf, &length) == -1) { + return nullptr; + } std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(data, len); + tflite::FlatBufferModel::BuildFromBuffer(buf, length); return model ? new InterpreterWrapper(std::move(model)) : nullptr; } diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h index 0972c572595f5044a305a81afaccbea5f131247c..c02aa3804367f787016ef78fc8557005507f051b 100644 --- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -19,6 +19,8 @@ limitations under the License. #include #include +// Place `` before to avoid build failures in macOS. +#include #include // We forward declare TFLite classes here to avoid exposing them to SWIG. @@ -40,8 +42,7 @@ class InterpreterWrapper { static InterpreterWrapper* CreateWrapperCPPFromFile(const char* model_path); // SWIG caller takes ownership of pointer. - static InterpreterWrapper* CreateWrapperCPPFromBuffer(const char* data, - size_t len); + static InterpreterWrapper* CreateWrapperCPPFromBuffer(PyObject* data); ~InterpreterWrapper(); bool AllocateTensors(); diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 86b25e68acaf5d74e3dd11784446e7bda3d329ee..876ffbbffa5a47b91d1318baf431050fe364aac0 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -16,23 +16,376 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. +@@TocoConverter @@toco_convert @@toco_convert_protos -@@tflite_from_saved_model @@Interpreter @@OpHint @@convert_op_hints_to_stubs +@@FLOAT +@@QUANTIZED_UINT8 +@@TFLITE +@@GRAPHVIZ_DOT + """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import +from six import PY3 + +from google.protobuf import text_format as _text_format +from google.protobuf.message import DecodeError +from tensorflow.contrib.lite.python import lite_constants as constants +from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert import tensor_name from tensorflow.contrib.lite.python.convert import toco_convert -from tensorflow.contrib.lite.python.convert import toco_convert_protos -from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model -from tensorflow.contrib.lite.python.interpreter import Interpreter -from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs -from tensorflow.contrib.lite.python.op_hint import OpHint -# pylint: enable=unused-import +from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model +from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names +from tensorflow.contrib.lite.python.convert_saved_model import set_tensor_shapes +from tensorflow.contrib.lite.python.interpreter import Interpreter # pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs # pylint: disable=unused-import +from tensorflow.contrib.lite.python.op_hint import OpHint # pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 as _graph_pb2 +from tensorflow.python.client import session as _session +from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.framework.importer import import_graph_def +from tensorflow.python.ops.variables import global_variables_initializer +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import tag_constants + + +class TocoConverter(object): + """Convert a TensorFlow model into `output_format` using TOCO. + + This is used to convert from a TensorFlow GraphDef or SavedModel into either a + TFLite FlatBuffer or graph visualization. + + Attributes: + + inference_type: Target data type of arrays in the output file. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of input arrays. Allows for a + different type for input arrays in the case of quantization. Currently + must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + output_format: Output file format. Currently must be `{TFLITE, + GRAPHVIZ_DOT}`. (default TFLITE) + quantized_input_stats: Dict of strings representing input tensor names + mapped to tuple of integers representing the mean and standard deviation + of the training data (e.g., {"foo" : (0., 1.)}). Only need if + `inference_type` is `QUANTIZED_UINT8`. (default {}) + default_ranges_stats: Tuple of integers representing (min, max) range values + for all arrays without a specified range. Intended for experimenting with + quantization via "dummy quantization". (default None) + drop_control_dependency: Boolean indicating whether to drop control + dependencies silently. This is due to TFLite not supporting control + dependencies. (default True) + reorder_across_fake_quant: Boolean indicating whether to reorder FakeQuant + nodes in unexpected locations. Used when the location of the FakeQuant + nodes is preventing graph transformations necessary to convert the graph. + Results in a graph that differs from the quantized training graph, + potentially causing differing arithmetic behavior. (default False) + change_concat_input_ranges: Boolean to change behavior of min/max ranges for + inputs and outputs of the concat operator for quantized models. Changes + the ranges of concat operator overlap when true. (default False) + allow_custom_ops: Boolean indicating whether to allow custom operations. + When false any unknown operation is an error. When true, custom ops are + created for any op that is unknown. The developer will need to provide + these to the TensorFlow Lite runtime with a custom resolver. + (default False) + quantize_weights: Boolean indicating whether to store weights as quantized + weights followed by dequantize operations. Computation is still done in + float, but reduces model size (at the cost of accuracy and latency). + (default False) + dump_graphviz_dir: Full filepath of folder to dump the graphs at various + stages of processing GraphViz .dot files. Preferred over + --output_format=GRAPHVIZ_DOT in order to keep the requirements of the + output file. (default None) + dump_graphviz_video: Boolean indicating whether to dump the graph after + every graph transformation. (default False) + + Example usage: + + # Converting a GraphDef from session. + converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + + # Converting a GraphDef from file. + converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, input_arrays, output_arrays) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) + + # Converting a SavedModel. + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + """ + + def __init__(self, graph_def, input_tensors, output_tensors): + """Constructor for TocoConverter. + + Args: + + graph_def: TensorFlow GraphDef. + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + """ + self._graph_def = graph_def + self._input_tensors = input_tensors + self._output_tensors = output_tensors + self.inference_type = constants.FLOAT + self.inference_input_type = None + self.output_format = constants.TFLITE + self.quantized_input_stats = {} + self.default_ranges_stats = None + self.drop_control_dependency = True + self.reorder_across_fake_quant = False + self.change_concat_input_ranges = False + self.allow_custom_ops = False + self.quantize_weights = False + self.dump_graphviz_dir = None + self.dump_graphviz_video = False + + @classmethod + def from_session(cls, sess, input_tensors, output_tensors): + """Creates a TocoConverter class from a TensorFlow Session. + + Args: + sess: TensorFlow Session. + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + TocoConverter class. + """ + graph_def = _freeze_graph(sess, output_tensors) + return cls(graph_def, input_tensors, output_tensors) + + @classmethod + def from_frozen_graph(cls, + graph_def_file, + input_arrays, + output_arrays, + input_shapes=None): + """Creates a TocoConverter class from a file containing a frozen GraphDef. + + Args: + graph_def_file: Full filepath of file containing TensorFlow GraphDef. + input_arrays: List of input tensors to freeze graph with. + output_arrays: List of output tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + + Returns: + TocoConverter class. + + Raises: + ValueError: + Unable to parse input file. + The graph is not frozen. + input_arrays or output_arrays contains an invalid tensor name. + """ + with _session.Session() as sess: + sess.run(global_variables_initializer()) + + # Read GraphDef from file. + graph_def = _graph_pb2.GraphDef() + with open(graph_def_file, "rb") as f: + file_content = f.read() + try: + graph_def.ParseFromString(file_content) + except (_text_format.ParseError, DecodeError): + try: + print("Ignore 'tcmalloc: large alloc' warnings.") + + if not isinstance(file_content, str): + if PY3: + file_content = file_content.decode('utf-8') + else: + file_content = file_content.encode('utf-8') + _text_format.Merge(file_content, graph_def) + except (_text_format.ParseError, DecodeError): + raise ValueError( + "Unable to parse input file '{}'.".format(graph_def_file)) + sess.graph.as_default() + import_graph_def(graph_def, name="") + + # Get input and output tensors. + input_tensors = get_tensors_from_tensor_names(sess.graph, input_arrays) + output_tensors = get_tensors_from_tensor_names(sess.graph, output_arrays) + set_tensor_shapes(input_tensors, input_shapes) + + # Check if graph is frozen. + if not _is_frozen_graph(sess): + raise ValueError("Please freeze the graph using freeze_graph.py.") + + # Create TocoConverter class. + return cls(sess.graph_def, input_tensors, output_tensors) + + @classmethod + def from_saved_model(cls, + saved_model_dir, + input_arrays=None, + input_shapes=None, + output_arrays=None, + tag_set=None, + signature_key=None): + """Creates a TocoConverter class from a SavedModel. + + Args: + saved_model_dir: SavedModel directory to convert. + input_arrays: List of input tensors to freeze graph with. Uses input + arrays from SignatureDef when none are provided. (default None) + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : + None}). (default None) + output_arrays: List of output tensors to freeze graph with. Uses output + arrays from SignatureDef when none are provided. (default None) + tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to + analyze. All tags in the tag set must be present. (default set("serve")) + signature_key: Key identifying SignatureDef containing inputs and outputs. + (default DEFAULT_SERVING_SIGNATURE_DEF_KEY) + + Returns: + TocoConverter class. + """ + if tag_set is None: + tag_set = set([tag_constants.SERVING]) + if signature_key is None: + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + result = freeze_saved_model(saved_model_dir, input_arrays, input_shapes, + output_arrays, tag_set, signature_key) + return cls( + graph_def=result[0], input_tensors=result[1], output_tensors=result[2]) + + def convert(self): + """Converts a TensorFlow GraphDef based on instance variables. + + Returns: + The converted data in serialized format. Either a TFLite Flatbuffer or a + Graphviz graph depending on value in `output_format`. + + Raises: + ValueError: + Input shape is not specified. + None value for dimension in input_tensor. + """ + # Checks dimensions in input tensor. + for tensor in self._input_tensors: + if not tensor.get_shape(): + raise ValueError("Provide an input shape for input array '{0}'.".format( + tensor_name(tensor))) + shape = tensor.get_shape().as_list() + if None in shape[1:]: + raise ValueError( + "None is only supported in the 1st dimension. Tensor '{0}' has " + "invalid shape '{1}'.".format(tensor_name(tensor), shape)) + elif shape[0] is None: + self._set_batch_size(batch_size=1) + + # Get quantization stats. Ensures there is one stat per name if the stats + # are specified. + if self.quantized_input_stats: + quantized_stats = [] + invalid_stats = [] + for tensor in self._input_tensors: + name = tensor_name(tensor) + if name in self.quantized_input_stats: + quantized_stats.append(self.quantized_input_stats[name]) + else: + invalid_stats.append(name) + + if invalid_stats: + raise ValueError("Quantization input stats are not available for input " + "tensors '{0}'.".format(",".join(invalid_stats))) + else: + quantized_stats = None + + # Converts model. + result = toco_convert( + input_data=self._graph_def, + input_tensors=self._input_tensors, + output_tensors=self._output_tensors, + inference_type=self.inference_type, + inference_input_type=self.inference_input_type, + input_format=constants.TENSORFLOW_GRAPHDEF, + output_format=self.output_format, + quantized_input_stats=quantized_stats, + default_ranges_stats=self.default_ranges_stats, + drop_control_dependency=self.drop_control_dependency, + reorder_across_fake_quant=self.reorder_across_fake_quant, + change_concat_input_ranges=self.change_concat_input_ranges, + allow_custom_ops=self.allow_custom_ops, + quantize_weights=self.quantize_weights, + dump_graphviz_dir=self.dump_graphviz_dir, + dump_graphviz_video=self.dump_graphviz_video) + return result + + def get_input_arrays(self): + """Returns a list of the names of the input tensors. + + Returns: + List of strings. + """ + return [tensor_name(tensor) for tensor in self._input_tensors] + + def _set_batch_size(self, batch_size): + """Sets the first dimension of the input tensor to `batch_size`. + + Args: + batch_size: Batch size for the model. Replaces the first dimension of an + input size array if undefined. (default 1) + """ + for tensor in self._input_tensors: + shape = tensor.get_shape().as_list() + shape[0] = batch_size + tensor.set_shape(shape) + + +def _is_frozen_graph(sess): + """Determines if the graph is frozen. + + Determines if a graph has previously been frozen by checking for any + operations of type Variable*. If variables are found, the graph is not frozen. + + Args: + sess: TensorFlow Session. + + Returns: + Bool. + """ + for op in sess.graph.get_operations(): + if op.type.startswith("Variable"): + return False + return True + + +def _freeze_graph(sess, output_tensors): + """Returns a frozen GraphDef. + + Freezes a graph with Variables in it. Otherwise the existing GraphDef is + returned. + + Args: + sess: TensorFlow Session. + output_tensors: List of output tensors (only .name is used from this). + + Returns: + Frozen GraphDef. + """ + if not _is_frozen_graph(sess): + sess.run(global_variables_initializer()) + output_arrays = [tensor_name(tensor) for tensor in output_tensors] + return tf_graph_util.convert_variables_to_constants(sess, sess.graph_def, + output_arrays) + else: + return sess.graph_def diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8c9d2c1651dd2d0b3cd27cf638c04429e3131efb --- /dev/null +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -0,0 +1,622 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lite.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite_constants +from tensorflow.contrib.lite.python.interpreter import Interpreter +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import saved_model +from tensorflow.python.training.training_util import write_graph + + +class FromSessionTest(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testQuantization(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = { + 'inputA': (0., 1.), + 'inputB': (0., 1.) + } # mean, std_dev + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), + input_details[0]['quantization']) # scale, zero_point + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.uint8, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((1., 0.), + input_details[1]['quantization']) # scale, zero_point + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('output', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + + def testQuantizationInvalid(self): + in_tensor_1 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') + in_tensor_2 = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') + out_tensor = array_ops.fake_quant_with_min_max_args( + in_tensor_1 + in_tensor_2, min=0., max=1., name='output') + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session( + sess, [in_tensor_1, in_tensor_2], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'Quantization input stats are not available for input tensors ' + '\'inputB\'.', str(error.exception)) + + def testSizeNoneInvalid(self): + in_tensor = array_ops.placeholder(dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test invalid shape. None after 1st dimension. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual('Provide an input shape for input array \'Placeholder\'.', + str(error.exception)) + + def testBatchSizeInvalid(self): + in_tensor = array_ops.placeholder( + shape=[1, None, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Test invalid shape. None after 1st dimension. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + with self.assertRaises(ValueError) as error: + converter.convert() + self.assertEqual( + 'None is only supported in the 1st dimension. Tensor ' + '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.', + str(error.exception)) + + def testBatchSizeValid(self): + in_tensor = array_ops.placeholder( + shape=[None, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + var + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.output_format = lite_constants.GRAPHVIZ_DOT + graphviz_output = converter.convert() + self.assertTrue(graphviz_output) + + # TODO(nupurgarg): Verify value of contents in GraphViz. + def testDumpGraphviz(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure interpreter is able to allocate and check graphviz data. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + num_items_graphviz = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + graphviz_dir = self.get_temp_dir() + converter.dump_graphviz_dir = graphviz_dir + converter.dump_graphviz_video = True + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensure graphviz folder has more data after using video flag. + num_items_graphviz_video = len(os.listdir(graphviz_dir)) + self.assertTrue(num_items_graphviz_video > num_items_graphviz) + + def testInferenceInputType(self): + in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.uint8) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_input_type = lite_constants.QUANTIZED_UINT8 + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + def testDefaultRangesStats(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.inference_type = lite_constants.QUANTIZED_UINT8 + converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev + converter.default_ranges_stats = (0, 6) # min, max + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.uint8, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((1., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.uint8, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertTrue(output_details[0]['quantization'][0] > 0) # scale + + def testQuantizeWeights(self): + np.random.seed(0) + # We need the tensor to have more than 1024 elements for quantize_weights + # to kick in. Thus, the [33, 33] shape. + in_tensor_1 = array_ops.placeholder( + shape=[33, 33], dtype=dtypes.float32, name='inputA') + in_tensor_2 = constant_op.constant( + np.random.uniform(low=-10., high=10., size=(33, 33)), + shape=[33, 33], + dtype=dtypes.float32, + name='inputB') + out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') + sess = session.Session() + + # Convert float model. + float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1], + [out_tensor]) + float_tflite = float_converter.convert() + self.assertTrue(float_tflite) + + # Convert quantized weights model. + quantized_weights_converter = lite.TocoConverter.from_session( + sess, [in_tensor_1], [out_tensor]) + quantized_weights_converter.quantize_weights = True + quantized_weights_tflite = quantized_weights_converter.convert() + self.assertTrue(quantized_weights_tflite) + + # Ensure that the quantized weights tflite model is smaller. + self.assertTrue(len(quantized_weights_tflite) < len(float_tflite)) + + +class FromFrozenGraphFile(test_util.TensorFlowTestCase): + + def testFloat(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testFloatWithShapesArray(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph( + graph_def_file, ['Placeholder'], ['add'], + input_shapes={'Placeholder': [1, 16, 16, 3]}) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + + def testFreezeGraph(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + var = variable_scope.get_variable( + 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + var + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') + write_graph(sess.graph_def, '', graph_def_file, False) + + # Ensure the graph with variables cannot be converted. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual('Please freeze the graph using freeze_graph.py.', + str(error.exception)) + + def testPbtxt(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + _ = in_tensor + in_tensor + sess = session.Session() + + # Write graph to file. + graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt') + write_graph(sess.graph_def, '', graph_def_file, True) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_frozen_graph(graph_def_file, + ['Placeholder'], ['add']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(1, len(input_details)) + self.assertEqual('Placeholder', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testInvalidFile(self): + graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') + with gfile.Open(graph_def_file, 'wb') as temp_file: + temp_file.write('bad data') + temp_file.flush() + + # Attempts to convert the invalid model. + with self.assertRaises(ValueError) as error: + lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'], + ['add']) + self.assertEqual( + 'Unable to parse input file \'{}\'.'.format(graph_def_file), + str(error.exception)) + + +class FromSavedModelTest(test_util.TensorFlowTestCase): + + def _createSavedModel(self, shape): + """Create a simple SavedModel.""" + saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel') + with session.Session() as sess: + in_tensor_1 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name='inputB') + in_tensor_2 = array_ops.placeholder( + shape=shape, dtype=dtypes.float32, name='inputA') + out_tensor = in_tensor_1 + in_tensor_2 + inputs = {'x': in_tensor_1, 'y': in_tensor_2} + outputs = {'z': out_tensor} + saved_model.simple_save(sess, saved_model_dir, inputs, outputs) + return saved_model_dir + + def testSimpleModel(self): + """Test a SavedModel.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testNoneBatchSize(self): + """Test a SavedModel, with None in input tensor's shape.""" + saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3]) + + converter = lite.TocoConverter.from_saved_model(saved_model_dir) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testOrderInputArrays(self): + """Test a SavedModel ordering of input arrays.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, input_arrays=['inputB', 'inputA']) + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check values from converted model. + interpreter = Interpreter(model_content=tflite_model) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + self.assertEqual(2, len(input_details)) + self.assertEqual('inputA', input_details[0]['name']) + self.assertEqual(np.float32, input_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) + self.assertEqual((0., 0.), input_details[0]['quantization']) + + self.assertEqual('inputB', input_details[1]['name']) + self.assertEqual(np.float32, input_details[1]['dtype']) + self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) + self.assertEqual((0., 0.), input_details[1]['quantization']) + + output_details = interpreter.get_output_details() + self.assertEqual(1, len(output_details)) + self.assertEqual('add', output_details[0]['name']) + self.assertEqual(np.float32, output_details[0]['dtype']) + self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) + self.assertEqual((0., 0.), output_details[0]['quantization']) + + def testSubsetInputArrays(self): + """Test a SavedModel with a subset of the input array names of the model.""" + saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) + + # Check case where input shape is given. + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, + input_arrays=['inputA'], + input_shapes={'inputA': [1, 16, 16, 3]}) + + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Check case where input shape is None. + converter = lite.TocoConverter.from_saved_model( + saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None}) + + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..f497533bed054d260aefc7b3fe67ae655c7cbcda --- /dev/null +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -0,0 +1,364 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python command line interface for running TOCO.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 +from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 +from tensorflow.python.platform import app + + +def _parse_array(values): + if values: + return values.split(",") + + +def _parse_int_array(values): + if values: + return [int(val) for val in values.split(",")] + + +def _parse_set(values): + if values: + return set(values.split(",")) + + +def _get_toco_converter(flags): + """Makes a TocoConverter object based on the flags provided. + + Args: + flags: argparse.Namespace object containing TFLite flags. + + Returns: + TocoConverter object. + """ + # Parse input and output arrays. + input_arrays = _parse_array(flags.input_arrays) + input_shapes = None + if flags.input_shapes: + input_shapes_list = [ + _parse_int_array(shape) for shape in flags.input_shapes.split(":") + ] + input_shapes = dict(zip(input_arrays, input_shapes_list)) + output_arrays = _parse_array(flags.output_arrays) + + converter_kwargs = { + "input_arrays": input_arrays, + "input_shapes": input_shapes, + "output_arrays": output_arrays + } + + # Create TocoConverter. + if flags.graph_def_file: + converter_fn = lite.TocoConverter.from_frozen_graph + converter_kwargs["graph_def_file"] = flags.graph_def_file + elif flags.saved_model_dir: + converter_fn = lite.TocoConverter.from_saved_model + converter_kwargs["saved_model_dir"] = flags.saved_model_dir + converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set) + converter_kwargs["signature_key"] = flags.saved_model_signature_key + + return converter_fn(**converter_kwargs) + + +def _convert_model(flags): + """Calls function to convert the TensorFlow model into a TFLite model. + + Args: + flags: argparse.Namespace object. + + Raises: + ValueError: Invalid flags. + """ + # Create converter. + converter = _get_toco_converter(flags) + if flags.inference_type: + converter.inference_type = _types_pb2.IODataType.Value(flags.inference_type) + if flags.inference_input_type: + converter.inference_input_type = _types_pb2.IODataType.Value( + flags.inference_input_type) + if flags.output_format: + converter.output_format = _toco_flags_pb2.FileFormat.Value( + flags.output_format) + + if flags.mean_values and flags.std_dev_values: + input_arrays = converter.get_input_arrays() + std_dev_values = _parse_int_array(flags.std_dev_values) + mean_values = _parse_int_array(flags.mean_values) + quant_stats = zip(mean_values, std_dev_values) + if ((not flags.input_arrays and len(input_arrays) > 1) or + (len(input_arrays) != len(quant_stats))): + raise ValueError("Mismatching --input_arrays, --std_dev_values, and " + "--mean_values. The flags must have the same number of " + "items. The current input arrays are '{0}'. " + "--input_arrays must be present when specifying " + "--std_dev_values and --mean_values with multiple input " + "tensors in order to map between names and " + "values.".format(",".join(input_arrays))) + converter.quantized_input_stats = dict(zip(input_arrays, quant_stats)) + if (flags.default_ranges_min is not None) and (flags.default_ranges_max is + not None): + converter.default_ranges_stats = (flags.default_ranges_min, + flags.default_ranges_max) + + if flags.drop_control_dependency: + converter.drop_control_dependency = flags.drop_control_dependency + if flags.reorder_across_fake_quant: + converter.reorder_across_fake_quant = flags.reorder_across_fake_quant + if flags.change_concat_input_ranges: + converter.change_concat_input_ranges = flags.change_concat_input_ranges + if flags.allow_custom_ops: + converter.allow_custom_ops = flags.allow_custom_ops + if flags.quantize_weights: + converter.quantize_weights = flags.quantize_weights + if flags.dump_graphviz_dir: + converter.dump_graphviz_dir = flags.dump_graphviz_dir + if flags.dump_graphviz_video: + converter.dump_graphviz_vode = flags.dump_graphviz_video + + # Convert model. + output_data = converter.convert() + with open(flags.output_file, "wb") as f: + f.write(output_data) + + +def _check_flags(flags, unparsed): + """Checks the parsed and unparsed flags to ensure they are valid. + + Raises an error if previously support unparsed flags are found. Raises an + error for parsed flags that don't meet the required conditions. + + Args: + flags: argparse.Namespace object containing TFLite flags. + unparsed: List of unparsed flags. + + Raises: + ValueError: Invalid flags. + """ + + # Check unparsed flags for common mistakes based on previous TOCO. + def _get_message_unparsed(flag, orig_flag, new_flag): + if flag.startswith(orig_flag): + return "\n Use {0} instead of {1}".format(new_flag, orig_flag) + return "" + + if unparsed: + output = "" + for flag in unparsed: + output += _get_message_unparsed(flag, "--input_file", "--graph_def_file") + output += _get_message_unparsed(flag, "--savedmodel_directory", + "--saved_model_dir") + output += _get_message_unparsed(flag, "--std_value", "--std_dev_values") + output += _get_message_unparsed(flag, "--batch_size", "--input_shapes") + output += _get_message_unparsed(flag, "--dump_graphviz", + "--dump_graphviz_dir") + if output: + raise ValueError(output) + + # Check that flags are valid. + if flags.graph_def_file and (not flags.input_arrays or + not flags.output_arrays): + raise ValueError("--input_arrays and --output_arrays are required with " + "--graph_def_file") + + if flags.input_shapes: + if not flags.input_arrays: + raise ValueError("--input_shapes must be used with --input_arrays") + if flags.input_shapes.count(":") != flags.input_arrays.count(","): + raise ValueError("--input_shapes and --input_arrays must have the same " + "number of items") + + if flags.std_dev_values or flags.mean_values: + if bool(flags.std_dev_values) != bool(flags.mean_values): + raise ValueError("--std_dev_values and --mean_values must be used " + "together") + if flags.std_dev_values.count(",") != flags.mean_values.count(","): + raise ValueError("--std_dev_values, --mean_values must have the same " + "number of items") + + if (flags.default_ranges_min is None) != (flags.default_ranges_max is None): + raise ValueError("--default_ranges_min and --default_ranges_max must be " + "used together") + + +def run_main(_): + """Main in toco_convert.py.""" + parser = argparse.ArgumentParser( + description=("Command line tool to run TensorFlow Lite Optimizing " + "Converter (TOCO).")) + + # Output file flag. + parser.add_argument( + "--output_file", + type=str, + help="Full filepath of the output file.", + required=True) + + # Input file flags. + input_file_group = parser.add_mutually_exclusive_group(required=True) + input_file_group.add_argument( + "--graph_def_file", + type=str, + help="Full filepath of file containing TensorFlow GraphDef.") + input_file_group.add_argument( + "--saved_model_dir", + type=str, + help="Full filepath of directory containing the SavedModel.") + + # Model format flags. + parser.add_argument( + "--output_format", + type=str.upper, + choices=["TFLITE", "GRAPHVIZ_DOT"], + help="Output file format.") + parser.add_argument( + "--inference_type", + type=str.upper, + choices=["FLOAT", "QUANTIZED_UINT8"], + help="Target data type of arrays in the output file.") + parser.add_argument( + "--inference_input_type", + type=str.upper, + choices=["FLOAT", "QUANTIZED_UINT8"], + help=("Target data type of input arrays. Allows for a different type for " + "input arrays in the case of quantization.")) + + # Input and output arrays flags. + parser.add_argument( + "--input_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + parser.add_argument( + "--input_shapes", + type=str, + help="Shapes corresponding to --input_arrays, colon-separated.") + parser.add_argument( + "--output_arrays", + type=str, + help="Names of the output arrays, comma-separated.") + + # SavedModel related flags. + parser.add_argument( + "--saved_model_tag_set", + type=str, + help=("Comma-separated set of tags identifying the MetaGraphDef within " + "the SavedModel to analyze. All tags must be present. " + "(default \"serve\")")) + parser.add_argument( + "--saved_model_signature_key", + type=str, + help=("Key identifying the SignatureDef containing inputs and outputs. " + "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)")) + + # Quantization flags. + parser.add_argument( + "--std_dev_values", + type=str, + help=("Standard deviation of training data for each input tensor, " + "comma-separated. Used for quantization. (default None)")) + parser.add_argument( + "--mean_values", + type=str, + help=("Mean of training data for each input tensor, comma-separated. " + "Used for quantization. (default None)")) + parser.add_argument( + "--default_ranges_min", + type=int, + help=("Default value for min bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--default_ranges_max", + type=int, + help=("Default value for max bound of min/max range values used for all " + "arrays without a specified range, Intended for experimenting with " + "quantization via \"dummy quantization\". (default None)")) + parser.add_argument( + "--quantize_weights", + type=bool, + help=("Store float weights as quantized weights followed by dequantize " + "operations. Inference is still done in FLOAT, but reduces model " + "size (at the cost of accuracy and latency).")) + + # Graph manipulation flags. + parser.add_argument( + "--drop_control_dependency", + action="store_true", + help=("Boolean indicating whether to drop control dependencies silently. " + "This is due to TensorFlow not supporting control dependencies. " + "(default True)")) + parser.add_argument( + "--reorder_across_fake_quant", + action="store_true", + help=("Boolean indicating whether to reorder FakeQuant nodes in " + "unexpected locations. Used when the location of the FakeQuant " + "nodes is preventing graph transformations necessary to convert " + "the graph. Results in a graph that differs from the quantized " + "training graph, potentially causing differing arithmetic " + "behavior. (default False)")) + parser.add_argument( + "--change_concat_input_ranges", + action="store_true", + help=("Boolean to change behavior of min/max ranges for inputs and " + "outputs of the concat operator for quantized models. Changes the " + "ranges of concat operator overlap when true. (default False)")) + parser.add_argument( + "--allow_custom_ops", + action="store_true", + help=("Boolean indicating whether to allow custom operations. When false " + "any unknown operation is an error. When true, custom ops are " + "created for any op that is unknown. The developer will need to " + "provide these to the TensorFlow Lite runtime with a custom " + "resolver. (default False)")) + + # Logging flags. + parser.add_argument( + "--dump_graphviz_dir", + type=str, + help=("Full filepath of folder to dump the graphs at various stages of " + "processing GraphViz .dot files. Preferred over --output_format=" + "GRAPHVIZ_DOT in order to keep the requirements of the output " + "file.")) + parser.add_argument( + "--dump_graphviz_video", + action="store_true", + help=("Boolean indicating whether to dump the graph after every graph " + "transformation")) + + tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:]) + try: + _check_flags(tflite_flags, unparsed) + except ValueError as e: + parser.print_usage() + file_name = os.path.basename(sys.argv[0]) + sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e))) + sys.exit(1) + _convert_model(tflite_flags) + + +def main(): + app.run(main=run_main, argv=sys.argv[:1]) + + +if __name__ == "__main__": + main() diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc index ac408d2f94b98d505afe4c951d7cc2ff960606fb..9dc8daa227dd68ccde2efa4013ac4465a72e6bb0 100644 --- a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -39,7 +39,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ // DO NOT EDIT MANUALLY: This file is automatically generated by -// `schema_builtin_ops_header_generator.py`. +// `schema/builtin_ops_header/generator.cc`. #ifdef __cplusplus extern "C" { @@ -57,7 +57,6 @@ const char* kFileFooter = } // extern "C" #endif // __cplusplus #endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ -} )"; } // anonymous namespace diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 8bdeb035f5a778fa3b0d85da36d6b8d6721445ea..ee5208df1456d01f1a99ecc69722f5fb4ab0763a 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -145,6 +145,12 @@ enum BuiltinOperator : byte { SLICE = 65, SIN = 66, TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, + TILE = 69, + EXPAND_DIMS = 70, + EQUAL = 71, + NOT_EQUAL = 72, + LOG = 73, } // Options for the builtin operators. @@ -198,6 +204,11 @@ union BuiltinOptions { SelectOptions, SliceOptions, TransposeConvOptions, + SparseToDenseOptions, + TileOptions, + ExpandDimsOptions, + EqualOptions, + NotEqualOptions, } enum Padding : byte { SAME, VALID } @@ -309,11 +320,23 @@ table LocalResponseNormalizationOptions { beta:float; } +enum LSTMKernelType : byte { + // Full LSTM kernel which supports peephole and projection. + FULL = 0, + // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. + BASIC = 1, +} + // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell table LSTMOptions { + // Parameters for LSTM version 1 or above. fused_activation_function:ActivationFunctionType; cell_clip: float; // Optional, 0.0 means no clipping proj_clip: float; // Optional, 0.0 means no clipping + + // Parameters for LSTM version 2 or above. + // Basic kernel is only supported in version 2 or above. + kernel_type: LSTMKernelType = FULL; } table ResizeBilinearOptions { @@ -419,6 +442,9 @@ table DequantizeOptions { table MaximumMinimumOptions { } +table TileOptions { +} + table ArgMaxOptions { output_type : TensorType; } @@ -450,6 +476,19 @@ table TransposeConvOptions { stride_h:int; } +table ExpandDimsOptions { +} + +table SparseToDenseOptions { + validate_indices:bool; +} + +table EqualOptions { +} + +table NotEqualOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 35c34f53a6bf9716941f623b43f238c681252747..887e47ed1ea309d025d4be8745ffb8da06e8ee6b 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -151,6 +151,9 @@ struct DequantizeOptionsT; struct MaximumMinimumOptions; struct MaximumMinimumOptionsT; +struct TileOptions; +struct TileOptionsT; + struct ArgMaxOptions; struct ArgMaxOptionsT; @@ -178,6 +181,18 @@ struct SliceOptionsT; struct TransposeConvOptions; struct TransposeConvOptionsT; +struct ExpandDimsOptions; +struct ExpandDimsOptionsT; + +struct SparseToDenseOptions; +struct SparseToDenseOptionsT; + +struct EqualOptions; +struct EqualOptionsT; + +struct NotEqualOptions; +struct NotEqualOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -305,11 +320,17 @@ enum BuiltinOperator { BuiltinOperator_SLICE = 65, BuiltinOperator_SIN = 66, BuiltinOperator_TRANSPOSE_CONV = 67, + BuiltinOperator_SPARSE_TO_DENSE = 68, + BuiltinOperator_TILE = 69, + BuiltinOperator_EXPAND_DIMS = 70, + BuiltinOperator_EQUAL = 71, + BuiltinOperator_NOT_EQUAL = 72, + BuiltinOperator_LOG = 73, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_MAX = BuiltinOperator_LOG }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[73] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -377,7 +398,13 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { BuiltinOperator_SELECT, BuiltinOperator_SLICE, BuiltinOperator_SIN, - BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_TRANSPOSE_CONV, + BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOperator_TILE, + BuiltinOperator_EXPAND_DIMS, + BuiltinOperator_EQUAL, + BuiltinOperator_NOT_EQUAL, + BuiltinOperator_LOG }; return values; } @@ -452,6 +479,12 @@ inline const char **EnumNamesBuiltinOperator() { "SLICE", "SIN", "TRANSPOSE_CONV", + "SPARSE_TO_DENSE", + "TILE", + "EXPAND_DIMS", + "EQUAL", + "NOT_EQUAL", + "LOG", nullptr }; return names; @@ -513,11 +546,16 @@ enum BuiltinOptions { BuiltinOptions_SelectOptions = 47, BuiltinOptions_SliceOptions = 48, BuiltinOptions_TransposeConvOptions = 49, + BuiltinOptions_SparseToDenseOptions = 50, + BuiltinOptions_TileOptions = 51, + BuiltinOptions_ExpandDimsOptions = 52, + BuiltinOptions_EqualOptions = 53, + BuiltinOptions_NotEqualOptions = 54, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_TransposeConvOptions + BuiltinOptions_MAX = BuiltinOptions_NotEqualOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[55] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -568,7 +606,12 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { BuiltinOptions_LessEqualOptions, BuiltinOptions_SelectOptions, BuiltinOptions_SliceOptions, - BuiltinOptions_TransposeConvOptions + BuiltinOptions_TransposeConvOptions, + BuiltinOptions_SparseToDenseOptions, + BuiltinOptions_TileOptions, + BuiltinOptions_ExpandDimsOptions, + BuiltinOptions_EqualOptions, + BuiltinOptions_NotEqualOptions }; return values; } @@ -625,6 +668,11 @@ inline const char **EnumNamesBuiltinOptions() { "SelectOptions", "SliceOptions", "TransposeConvOptions", + "SparseToDenseOptions", + "TileOptions", + "ExpandDimsOptions", + "EqualOptions", + "NotEqualOptions", nullptr }; return names; @@ -835,6 +883,26 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TileOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpandDimsOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EqualOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_NotEqualOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1258,6 +1326,46 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeConvOptions ? reinterpret_cast(value) : nullptr; } + SparseToDenseOptionsT *AsSparseToDenseOptions() { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + const SparseToDenseOptionsT *AsSparseToDenseOptions() const { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + TileOptionsT *AsTileOptions() { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + const TileOptionsT *AsTileOptions() const { + return type == BuiltinOptions_TileOptions ? + reinterpret_cast(value) : nullptr; + } + ExpandDimsOptionsT *AsExpandDimsOptions() { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + const ExpandDimsOptionsT *AsExpandDimsOptions() const { + return type == BuiltinOptions_ExpandDimsOptions ? + reinterpret_cast(value) : nullptr; + } + EqualOptionsT *AsEqualOptions() { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + const EqualOptionsT *AsEqualOptions() const { + return type == BuiltinOptions_EqualOptions ? + reinterpret_cast(value) : nullptr; + } + NotEqualOptionsT *AsNotEqualOptions() { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } + const NotEqualOptionsT *AsNotEqualOptions() const { + return type == BuiltinOptions_NotEqualOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -1365,6 +1473,35 @@ inline const char *EnumNameLSHProjectionType(LSHProjectionType e) { return EnumNamesLSHProjectionType()[index]; } +enum LSTMKernelType { + LSTMKernelType_FULL = 0, + LSTMKernelType_BASIC = 1, + LSTMKernelType_MIN = LSTMKernelType_FULL, + LSTMKernelType_MAX = LSTMKernelType_BASIC +}; + +inline LSTMKernelType (&EnumValuesLSTMKernelType())[2] { + static LSTMKernelType values[] = { + LSTMKernelType_FULL, + LSTMKernelType_BASIC + }; + return values; +} + +inline const char **EnumNamesLSTMKernelType() { + static const char *names[] = { + "FULL", + "BASIC", + nullptr + }; + return names; +} + +inline const char *EnumNameLSTMKernelType(LSTMKernelType e) { + const size_t index = static_cast(e); + return EnumNamesLSTMKernelType()[index]; +} + enum CombinerType { CombinerType_SUM = 0, CombinerType_MEAN = 1, @@ -2802,10 +2939,12 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; float cell_clip; float proj_clip; + LSTMKernelType kernel_type; LSTMOptionsT() : fused_activation_function(ActivationFunctionType_NONE), cell_clip(0.0f), - proj_clip(0.0f) { + proj_clip(0.0f), + kernel_type(LSTMKernelType_FULL) { } }; @@ -2814,7 +2953,8 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, - VT_PROJ_CLIP = 8 + VT_PROJ_CLIP = 8, + VT_KERNEL_TYPE = 10 }; ActivationFunctionType fused_activation_function() const { return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); @@ -2825,11 +2965,15 @@ struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { float proj_clip() const { return GetField(VT_PROJ_CLIP, 0.0f); } + LSTMKernelType kernel_type() const { + return static_cast(GetField(VT_KERNEL_TYPE, 0)); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && VerifyField(verifier, VT_PROJ_CLIP) && + VerifyField(verifier, VT_KERNEL_TYPE) && verifier.EndTable(); } LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2849,6 +2993,9 @@ struct LSTMOptionsBuilder { void add_proj_clip(float proj_clip) { fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); } + void add_kernel_type(LSTMKernelType kernel_type) { + fbb_.AddElement(LSTMOptions::VT_KERNEL_TYPE, static_cast(kernel_type), 0); + } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2865,10 +3012,12 @@ inline flatbuffers::Offset CreateLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, float cell_clip = 0.0f, - float proj_clip = 0.0f) { + float proj_clip = 0.0f, + LSTMKernelType kernel_type = LSTMKernelType_FULL) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); + builder_.add_kernel_type(kernel_type); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } @@ -4131,6 +4280,46 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions( flatbuffers::Offset CreateMaximumMinimumOptions(flatbuffers::FlatBufferBuilder &_fbb, const MaximumMinimumOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct TileOptionsT : public flatbuffers::NativeTable { + typedef TileOptions TableType; + TileOptionsT() { + } +}; + +struct TileOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TileOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TileOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TileOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit TileOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TileOptionsBuilder &operator=(const TileOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTileOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + TileOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ArgMaxOptionsT : public flatbuffers::NativeTable { typedef ArgMaxOptions TableType; TensorType output_type; @@ -4543,6 +4732,180 @@ inline flatbuffers::Offset CreateTransposeConvOptions( flatbuffers::Offset CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct ExpandDimsOptionsT : public flatbuffers::NativeTable { + typedef ExpandDimsOptions TableType; + ExpandDimsOptionsT() { + } +}; + +struct ExpandDimsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ExpandDimsOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpandDimsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpandDimsOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ExpandDimsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ExpandDimsOptionsBuilder &operator=(const ExpandDimsOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateExpandDimsOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + ExpandDimsOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct SparseToDenseOptionsT : public flatbuffers::NativeTable { + typedef SparseToDenseOptions TableType; + bool validate_indices; + SparseToDenseOptionsT() + : validate_indices(false) { + } +}; + +struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SparseToDenseOptionsT NativeTableType; + enum { + VT_VALIDATE_INDICES = 4 + }; + bool validate_indices() const { + return GetField(VT_VALIDATE_INDICES, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALIDATE_INDICES) && + verifier.EndTable(); + } + SparseToDenseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparseToDenseOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_validate_indices(bool validate_indices) { + fbb_.AddElement(SparseToDenseOptions::VT_VALIDATE_INDICES, static_cast(validate_indices), 0); + } + explicit SparseToDenseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SparseToDenseOptionsBuilder &operator=(const SparseToDenseOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSparseToDenseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool validate_indices = false) { + SparseToDenseOptionsBuilder builder_(_fbb); + builder_.add_validate_indices(validate_indices); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct EqualOptionsT : public flatbuffers::NativeTable { + typedef EqualOptions TableType; + EqualOptionsT() { + } +}; + +struct EqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef EqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + EqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct EqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit EqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + EqualOptionsBuilder &operator=(const EqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + EqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct NotEqualOptionsT : public flatbuffers::NativeTable { + typedef NotEqualOptions TableType; + NotEqualOptionsT() { + } +}; + +struct NotEqualOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef NotEqualOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + NotEqualOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct NotEqualOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit NotEqualOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + NotEqualOptionsBuilder &operator=(const NotEqualOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateNotEqualOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + NotEqualOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -4821,6 +5184,21 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { return builtin_options_type() == BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; } + const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { + return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast(builtin_options()) : nullptr; + } + const TileOptions *builtin_options_as_TileOptions() const { + return builtin_options_type() == BuiltinOptions_TileOptions ? static_cast(builtin_options()) : nullptr; + } + const ExpandDimsOptions *builtin_options_as_ExpandDimsOptions() const { + return builtin_options_type() == BuiltinOptions_ExpandDimsOptions ? static_cast(builtin_options()) : nullptr; + } + const EqualOptions *builtin_options_as_EqualOptions() const { + return builtin_options_type() == BuiltinOptions_EqualOptions ? static_cast(builtin_options()) : nullptr; + } + const NotEqualOptions *builtin_options_as_NotEqualOptions() const { + return builtin_options_type() == BuiltinOptions_NotEqualOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -5043,6 +5421,26 @@ template<> inline const TransposeConvOptions *Operator::builtin_options_as inline const SparseToDenseOptions *Operator::builtin_options_as() const { + return builtin_options_as_SparseToDenseOptions(); +} + +template<> inline const TileOptions *Operator::builtin_options_as() const { + return builtin_options_as_TileOptions(); +} + +template<> inline const ExpandDimsOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpandDimsOptions(); +} + +template<> inline const EqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_EqualOptions(); +} + +template<> inline const NotEqualOptions *Operator::builtin_options_as() const { + return builtin_options_as_NotEqualOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -6008,6 +6406,7 @@ inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; { auto _e = cell_clip(); _o->cell_clip = _e; }; { auto _e = proj_clip(); _o->proj_clip = _e; }; + { auto _e = kernel_type(); _o->kernel_type = _e; }; } inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -6021,11 +6420,13 @@ inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBuffe auto _fused_activation_function = _o->fused_activation_function; auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; + auto _kernel_type = _o->kernel_type; return tflite::CreateLSTMOptions( _fbb, _fused_activation_function, _cell_clip, - _proj_clip); + _proj_clip, + _kernel_type); } inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -6643,6 +7044,29 @@ inline flatbuffers::Offset CreateMaximumMinimumOptions(fl _fbb); } +inline TileOptionsT *TileOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TileOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TileOptions::UnPackTo(TileOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset TileOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTileOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateTileOptions(flatbuffers::FlatBufferBuilder &_fbb, const TileOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TileOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTileOptions( + _fbb); +} + inline ArgMaxOptionsT *ArgMaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ArgMaxOptionsT(); UnPackTo(_o, _resolver); @@ -6862,6 +7286,101 @@ inline flatbuffers::Offset CreateTransposeConvOptions(flat _stride_h); } +inline ExpandDimsOptionsT *ExpandDimsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ExpandDimsOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ExpandDimsOptions::UnPackTo(ExpandDimsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset ExpandDimsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpandDimsOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateExpandDimsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpandDimsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ExpandDimsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpandDimsOptions( + _fbb); +} + +inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SparseToDenseOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SparseToDenseOptions::UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = validate_indices(); _o->validate_indices = _e; }; +} + +inline flatbuffers::Offset SparseToDenseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparseToDenseOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SparseToDenseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _validate_indices = _o->validate_indices; + return tflite::CreateSparseToDenseOptions( + _fbb, + _validate_indices); +} + +inline EqualOptionsT *EqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new EqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void EqualOptions::UnPackTo(EqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset EqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const EqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateEqualOptions( + _fbb); +} + +inline NotEqualOptionsT *NotEqualOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new NotEqualOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void NotEqualOptions::UnPackTo(NotEqualOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset NotEqualOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateNotEqualOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateNotEqualOptions(flatbuffers::FlatBufferBuilder &_fbb, const NotEqualOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const NotEqualOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateNotEqualOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7244,6 +7763,26 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7458,6 +7997,26 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -7660,6 +8219,26 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + return CreateTileOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpandDimsOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateEqualOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + return CreateNotEqualOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7862,6 +8441,26 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeConvOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_SparseToDenseOptions: { + value = new SparseToDenseOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TileOptions: { + value = new TileOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpandDimsOptions: { + value = new ExpandDimsOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_EqualOptions: { + value = new EqualOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_NotEqualOptions: { + value = new NotEqualOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -8114,6 +8713,31 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_TileOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_ExpandDimsOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_EqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_NotEqualOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/simple_memory_arena.cc b/tensorflow/contrib/lite/simple_memory_arena.cc index 2f2004f56bcad5b56f9dd6d4bc824ec14d79e795..4eaf6f1bfe76efc1e6737d03d58be9bc87bb849d 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.cc +++ b/tensorflow/contrib/lite/simple_memory_arena.cc @@ -36,6 +36,12 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, ArenaAlloc* new_alloc) { TF_LITE_ENSURE(context, alignment < arena_alignment_); + if (size == 0) { + new_alloc->offset = 0; + new_alloc->size = 0; + return kTfLiteOk; + } + size_t current_top = 0; if (!allocs_.empty()) { @@ -75,6 +81,10 @@ TfLiteStatus SimpleMemoryArena::Allocate(TfLiteContext* context, TfLiteStatus SimpleMemoryArena::Deallocate(TfLiteContext* context, const ArenaAlloc& alloc) { + if (alloc.size == 0) { + return kTfLiteOk; + } + int erased_allocs_count = 0; auto it = allocs_.begin(); while (it != allocs_.end()) { @@ -122,7 +132,11 @@ TfLiteStatus SimpleMemoryArena::ResolveAlloc(TfLiteContext* context, char** output_ptr) { TF_LITE_ENSURE(context, committed_); TF_LITE_ENSURE(context, output_ptr != nullptr); - *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + if (alloc.size == 0) { + *output_ptr = nullptr; + } else { + *output_ptr = underlying_buffer_aligned_ptr_ + alloc.offset; + } return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h index 5faf78b59e3755d22e4e866d433e622baa6c66c1..f738315cf2f91403f9dcb6fa9e66b40bd70495aa 100644 --- a/tensorflow/contrib/lite/simple_memory_arena.h +++ b/tensorflow/contrib/lite/simple_memory_arena.h @@ -39,7 +39,8 @@ struct ArenaAlloc { // This small class is responsible for allocating, deallocating and reusing // dynamic memory from a common underlying buffer. The arena can be used in // scenarios when the pattern of memory allocations and deallocations is -// repetitive, e.g. running NN inference in multiple iterations. +// repetitive, e.g. running NN inference in multiple iterations. Note that +// zero-sized allocations are explicitly allowed, and will resolve to null. class SimpleMemoryArena { public: explicit SimpleMemoryArena(size_t arena_alignment) diff --git a/tensorflow/contrib/lite/simple_memory_arena_test.cc b/tensorflow/contrib/lite/simple_memory_arena_test.cc index 4444f642eb75c563c57762d095e454ac63d836c6..60d4d5e768aeda958574422e1c36a7cc2f6a1429 100644 --- a/tensorflow/contrib/lite/simple_memory_arena_test.cc +++ b/tensorflow/contrib/lite/simple_memory_arena_test.cc @@ -43,6 +43,47 @@ TEST(SimpleMemoryArenaTest, BasicArenaOperations) { EXPECT_EQ(allocs[5].offset, 1024); } +TEST(SimpleMemoryArenaTest, BasicZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc alloc; + + // Zero-sized allocs should have a 0 offset and size. + ASSERT_EQ(arena.Allocate(&context, 32, 0, &alloc), kTfLiteOk); + EXPECT_EQ(alloc.offset, 0); + EXPECT_EQ(alloc.size, 0); + + // Deallocation of zero-sized allocs should always succeed (even redundantly). + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, alloc), kTfLiteOk); + + // The zero-sized alloc should resolve to null. + char* resolved_ptr = nullptr; + ASSERT_EQ(arena.Commit(&context), kTfLiteOk); + ASSERT_EQ(arena.ResolveAlloc(&context, alloc, &resolved_ptr), kTfLiteOk); + EXPECT_EQ(resolved_ptr, nullptr); +} + +TEST(SimpleMemoryArenaTest, InterleavedZeroAlloc) { + TfLiteContext context; + SimpleMemoryArena arena(64); + ArenaAlloc allocs[4]; + + // Interleave some zero and non-zero-sized allocations and deallocations. + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[0]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 0, &allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 1023, &allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[1]), kTfLiteOk); + ASSERT_EQ(arena.Deallocate(&context, allocs[2]), kTfLiteOk); + ASSERT_EQ(arena.Allocate(&context, 32, 2047, &allocs[3]), kTfLiteOk); + + // Deallocation of a zero-sized alloc should not impact the allocator offsets. + EXPECT_EQ(allocs[0].offset, 0); + EXPECT_EQ(allocs[1].offset, 0); + EXPECT_EQ(allocs[2].offset, 2048); + EXPECT_EQ(allocs[3].offset, 2048); +} + TEST(SimpleMemoryArenaTest, TestAfterClear) { TfLiteContext context; SimpleMemoryArena arena(64); diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 74fc32a12b12ec3bca81590a74b81bc3caff0d96..b823c97f38e7660652aa0ce3538b11de59dc9aea 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -20,11 +20,15 @@ load( size = "large", srcs = ["generated_examples_zip_test.cc"], args = [ - "--zip_file_path=$(location :zip_%s)" % test_name, - # TODO(angerson) We may be able to add an external unzip binary instead - # of relying on an existing one for OSS builds. - "--unzip_binary_path=/usr/bin/unzip", - ], + ] + select({ + "//tensorflow:android": [], + "//conditions:default": [ + "--zip_file_path=$(location :zip_%s)" % test_name, + # TODO(angerson) We may be able to add an external unzip binary instead + # of relying on an existing one for OSS builds. + "--unzip_binary_path=/usr/bin/unzip", + ], + }), data = [ ":zip_%s" % test_name, ], @@ -155,6 +159,7 @@ cc_library( deps = [ ":split", ":test_runner", + "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:builtin_ops", ], diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 0e036bda92e4c42056f3f5df9fd8bcddcd932c13..f5e25784fa17209af7cfb06d32aeea2b9b947196 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -58,10 +58,11 @@ from tensorflow.python.ops import rnn parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") parser.add_argument("output_path", help="Directory where the outputs will be go.") -parser.add_argument("--zip_to_output", - type=str, - help="Particular zip to output.", - required=False) +parser.add_argument( + "--zip_to_output", + type=str, + help="Particular zip to output.", + required=True) parser.add_argument("--toco", type=str, help="Path to toco tool.", @@ -97,8 +98,6 @@ KNOWN_BUGS = { r"fully_connected.*transpose_.=True": "67586970", # Softmax graphs are too complex. r"softmax.*dim=0": "67749831", - # SpaceToDepth only supports float32. - r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", # Div will use floordiv. @@ -118,6 +117,8 @@ class ExtraTocoOptions(object): self.allow_custom_ops = False # Rnn states that are used to support rnn / lstm cells. self.rnn_states = None + # Split the LSTM inputs from 5 inoputs to 18 inputs for TFLite. + self.split_tflite_lstm_inputs = None def toco_options(data_types, @@ -146,14 +147,20 @@ def toco_options(data_types, " --inference_type=%s" % inference_type + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + " --input_arrays=%s" % ",".join(input_arrays) + - " --input_shapes=%s" % shape_str + " --output_arrays=%s" % ",".join(output_arrays)) + if shape_str: + s += (" --input_shapes=%s" % shape_str) if extra_toco_options.drop_control_dependency: s += " --drop_control_dependency" if extra_toco_options.allow_custom_ops: s += " --allow_custom_ops" if extra_toco_options.rnn_states: s += (" --rnn_states='" + extra_toco_options.rnn_states + "'") + if extra_toco_options.split_tflite_lstm_inputs is not None: + if extra_toco_options.split_tflite_lstm_inputs: + s += " --split_tflite_lstm_inputs=true" + else: + s += " --split_tflite_lstm_inputs=false" return s @@ -238,6 +245,19 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): return value.astype(dtype) +def create_scalar_data(dtype, min_value=-100, max_value=100): + """Build scalar tensor data range from min_value to max_value exclusively.""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value - min_value) * np.random.random() + min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.randint(min_value, max_value + 1) + return np.array(value, dtype=dtype) + + def freeze_graph(session, outputs): """Freeze the current graph. @@ -385,7 +405,7 @@ def make_zip_of_tests(zip_path, for parameters in test_parameters: keys = parameters.keys() for curr in itertools.product(*parameters.values()): - label = zip_path.replace(".zip", "") + (",".join( + label = zip_path.replace(".zip", "_") + (",".join( "%s=%r" % z for z in sorted(zip(keys, curr))).replace(" ", "")) if label[0] == "/": label = label[1:] @@ -447,6 +467,11 @@ def make_zip_of_tests(zip_path, sess, tf.global_variables() + inputs + outputs) if use_frozen_graph else sess.graph_def + + if "split_tflite_lstm_inputs" in param_dict_real: + extra_toco_options.split_tflite_lstm_inputs = param_dict_real[ + "split_tflite_lstm_inputs"] + tflite_model_binary, toco_log = toco_convert( graph_def.SerializeToString(), input_tensors, output_tensors, extra_toco_options) @@ -730,65 +755,83 @@ def make_binary_op_tests(zip_path, binary_operator): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_mean_tests(zip_path): - """Make a set of tests to do mean.""" +def make_reduce_tests(reduce_op): + """Make a set of tests to do reduce operation. - test_parameters = [{ - "input_dtype": [tf.float32, tf.int32, tf.int64], - "input_shape": [[3, 2, 4]], - "axis": [ - None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], - [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], - [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }, { - "input_dtype": [tf.float32], - "input_shape": [[1, 8, 8, 3]], - "axis": [ - None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], - [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, - -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], - [2, 2, 3], [-3, -3, -4], [-3, 2, 1] - ], - "const_axis": [True, False], - "keepdims": [True, False], - }] + Args: + reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`. - def build_graph(parameters): - """Build the mean op testing graph.""" - input_tensor = tf.placeholder( - dtype=parameters["input_dtype"], - name="input", - shape=parameters["input_shape"]) + Returns: + a function representing the true generator with `reduce_op_in` curried. + """ - # Get axis as either a placeholder or constants. - if parameters["const_axis"]: - axis = parameters["axis"] - input_tensors = [input_tensor] - else: - if isinstance(parameters["axis"], list): - shape = [len(parameters["axis"])] + def f(zip_path): + """Actual function that generates examples.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape": [[3, 2, 4]], + "axis": [ + None, 0, 1, 2, [0, 1], [0, 2], [1, 2], [0, 1, 2], [1, 0], [2, 0], + [2, 1], [2, 1, 0], [2, 0, 1], -1, -2, -3, [1, -1], [0, -1], [-1, 0], + [-1, -2, -3], [0, 0, 0], [2, 2, 0], [1, 0, -3, -3] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }, { + "input_dtype": [tf.float32], + "input_shape": [[1, 8, 8, 3]], + "axis": [ + None, 0, 1, 2, 3, [1, 2], [0, 3], [1, 2, 3], [0, 1, 2, 3], + [3, 2, 1, 0], [3, 1, 0, 2], [2, 0], [3, 0], [3, 1], [1, 0], -1, -2, + -3, -4, [0, -2], [2, 3, -1, 0], [3, 1, 2, -3], [3, -4], [2, 2, 2], + [2, 2, 3], [-3, -3, -4], [-3, 2, 1] + ], + "const_axis": [True, False], + "keepdims": [True, False], + }] + + def build_graph(parameters): + """Build the mean op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + # Get axis as either a placeholder or constants. + if parameters["const_axis"]: + axis = parameters["axis"] + input_tensors = [input_tensor] else: - shape = [0] # shape for None or integers. - axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) - input_tensors = [input_tensor, axis] + if isinstance(parameters["axis"], list): + shape = [len(parameters["axis"])] + else: + shape = [0] # shape for None or integers. + axis = tf.placeholder(dtype=tf.int32, name="axis", shape=shape) + input_tensors = [input_tensor, axis] - out = tf.reduce_mean( - input_tensor, axis=axis, keepdims=parameters["keepdims"]) - return input_tensors, [out] + out = reduce_op( + input_tensor, axis=axis, keepdims=parameters["keepdims"]) + return input_tensors, [out] - def build_inputs(parameters, sess, inputs, outputs): - values = [ - create_tensor_data(parameters["input_dtype"], parameters["input_shape"]) - ] - if not parameters["const_axis"]: - if parameters["axis"]: - values.append(np.array(parameters["axis"])) - return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], + parameters["input_shape"])] + if not parameters["const_axis"]: + if parameters["axis"]: + values.append(np.array(parameters["axis"])) + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + +def make_mean_tests(zip_path): + """Make a set of tests to do mean.""" + + return make_reduce_tests(tf.reduce_mean)(zip_path) def make_exp_tests(zip_path): @@ -1577,7 +1620,7 @@ def make_space_to_depth_tests(zip_path): """Make a set of tests to do space_to_depth.""" test_parameters = [{ - "dtype": [tf.float32, tf.float16, tf.int32, tf.uint8, tf.int64], + "dtype": [tf.float32, tf.int32, tf.uint8, tf.int64], "input_shape": [[2, 12, 24, 1]], "block_size": [2, 3, 4], }] @@ -1987,6 +2030,7 @@ def make_lstm_tests(zip_path): "time_step_size": [1], "input_vec_size": [3], "num_cells": [4], + "split_tflite_lstm_inputs": [True, False], }, ] @@ -2121,6 +2165,74 @@ def make_arg_max_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_equal_tests(zip_path): + """Make a set of tests to do equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the equal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_not_equal_tests(zip_path): + """Make a set of tests to do not equal.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32, tf.int64], + "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]), + ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]), + ([5, 5], [1]), ([10], [2, 4, 10])], + }] + + def build_graph(parameters): + """Build the not euqal op testing graph.""" + input_value1 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape_pair"][0]) + input_value2 = tf.placeholder( + dtype=parameters["input_dtype"], + name="input2", + shape=parameters["input_shape_pair"][1]) + out = tf.not_equal(input_value1, input_value2) + return [input_value1, input_value2], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value1 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][0]) + input_value2 = create_tensor_data(parameters["input_dtype"], + parameters["input_shape_pair"][1]) + return [input_value1, input_value2], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_greater_tests(zip_path): """Make a set of tests to do greater.""" @@ -2308,30 +2420,44 @@ def make_neg_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) -def make_sin_tests(zip_path): - """Make a set of tests to do sin.""" +def _make_elementwise_tests(op): + """Make a set of tests to do element-wise operations.""" - test_parameters = [{ - "input_dtype": [tf.float32], - "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], - }] + def f(zip_path): + """Actual function that generates examples.""" + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]], + }] - def build_graph(parameters): - """Build the sin op testing graph.""" - input_value = tf.placeholder( - dtype=parameters["input_dtype"], - name="input1", - shape=parameters["input_shape"]) - out = tf.sin(input_value) - return [input_value], [out] + def build_graph(parameters): + """Build the sin op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input1", + shape=parameters["input_shape"]) + out = op(input_value) + return [input_value], [out] - def build_inputs(parameters, sess, inputs, outputs): - input_value = create_tensor_data(parameters["input_dtype"], - parameters["input_shape"]) - return [input_value], sess.run( - outputs, feed_dict={inputs[0]: input_value}) + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict={inputs[0]: input_value}) - make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + return f + + +def make_sin_tests(zip_path): + """Make a set of tests to do sin.""" + return _make_elementwise_tests(tf.sin)(zip_path) + + +def make_log_tests(zip_path): + """Make a set of tests to do log.""" + return _make_elementwise_tests(tf.log)(zip_path) def make_where_tests(zip_path): @@ -2485,6 +2611,134 @@ def make_transpose_conv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_tile_tests(zip_path): + """Make a set of tests to do tile.""" + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[3, 2, 1], [2, 2, 2]], + "multiplier_dtype": [tf.int32, tf.int64], + "multiplier_shape": [[3]] + }] + + def build_graph(parameters): + """Build the tile op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + shape=parameters["input_shape"], + name="input") + multiplier_value = tf.placeholder( + dtype=parameters["multiplier_dtype"], + shape=parameters["multiplier_shape"], + name="multiplier") + out = tf.tile(input_value, multiplier_value) + return [input_value, multiplier_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + multipliers_value = create_tensor_data(parameters["multiplier_dtype"], + parameters["multiplier_shape"]) + return [input_value, multipliers_value], sess.run( + outputs, + feed_dict={ + inputs[0]: input_value, + inputs[1]: multipliers_value + }) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_expand_dims_tests(zip_path): + """Make a set of tests to do expand_dims.""" + + test_parameters = [{ + "input_type": [tf.float32, tf.int32], + "input_shape": [[3, 4], [10, 10, 3]], + "axis_value": [0, 1, 2, -1, -2], + }] + + def build_graph(parameters): + """Build the where op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_type"], + name="input", + shape=parameters["input_shape"]) + axis_value = tf.placeholder(dtype=tf.int32, name="axis", shape=[1]) + out = tf.expand_dims(input_value, axis=axis_value) + return [input_value, axis_value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_type"], + parameters["input_shape"]) + axis_value = np.array([parameters["axis_value"]], dtype=np.int32) + return [input_value, axis_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value, axis_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_sparse_to_dense_tests(zip_path): + """Make a set of tests to do sparse to dense.""" + + test_parameters = [{ + "value_dtype": [tf.float32, tf.int32], + "index_dtype": [tf.int32, tf.int64], + "value_count": [1, 3, 6, 8], + "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]], + "default_value": [0, -1], + "value_is_scalar": [True, False], + }] + + # Return a single value for 1-D dense shape, but a tuple for other shapes. + def generate_index(dense_shape): + if len(dense_shape) == 1: + return np.random.randint(dense_shape[0]) + else: + index = [] + for shape in dense_shape: + index.append(np.random.randint(shape)) + return tuple(index) + + def build_graph(parameters): + """Build the sparse_to_dense op testing graph.""" + dense_shape = parameters["dense_shape"] + + # Special handle for value_is_scalar case. + # value_count must be 1. + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + value = tf.placeholder( + name="value", dtype=parameters["value_dtype"], shape=()) + else: + value = tf.placeholder( + name="value", + dtype=parameters["value_dtype"], + shape=[parameters["value_count"]]) + indices = set() + while len(indices) < parameters["value_count"]: + indices.add(generate_index(dense_shape)) + indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"]) + # TODO(renjieliu): Add test for validate_indices case. + out = tf.sparse_to_dense( + indices, + dense_shape, + value, + parameters["default_value"], + validate_indices=False) + + return [value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + input_value = create_scalar_data(parameters["value_dtype"]) + else: + input_value = create_tensor_data(parameters["value_dtype"], + [parameters["value_count"]]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 2f069ff8e79b4a08824121c49e9327619cfeb858..8a59d756f8dbbcefc930b5285c1ced8ce6b08845 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -36,7 +36,12 @@ bool FLAGS_ignore_known_bugs = true; // TODO(b/71769302) zip_files_dir should have a more accurate default, if // possible string* FLAGS_zip_file_path = new string("./"); +#ifndef __ANDROID__ string* FLAGS_unzip_binary_path = new string("/usr/bin/unzip"); +#else +string* FLAGS_unzip_binary_path = new string("/system/bin/unzip"); +#endif +bool FLAGS_use_nnapi = false; } // namespace // TensorFlow system environment for file system called. @@ -48,7 +53,7 @@ tensorflow::Env* env = tensorflow::Env::Default(); // TODO(ahentz): make sure we clean this list up frequently. std::map kBrokenTests = { // Add only supports float32. (and "constant" tests use Add) - {R"(^\/adda.*int32)", "68808744"}, + {R"(^\/add_a.*int32)", "68808744"}, {R"(^\/constant.*int32)", "68808744"}, {R"(^\/mul.*int32)", "68808744"}, {R"(^\/div.*int32)", "68808744"}, @@ -61,25 +66,25 @@ std::map kBrokenTests = { "70527055"}, // L2Norm only supports tensors with 4D or fewer. - {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, + {R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, // SpaceToBatchND only supports 4D tensors. {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, // L2Norm only works for dim=-1. - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[.,.\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=-2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=2,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=0,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, + {R"(^\/l2norm_dim=1,epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[3,15,14,3\])", "67963812"}, - {R"(^\/l2normdim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, + {R"(^\/l2norm_dim=\[2,3\],epsilon=.*,input_shape=\[1,3,4,3\])", "67963812"}, // ResizeBilinear looks completely incompatible with Tensorflow {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"}, @@ -212,7 +217,7 @@ TEST_P(OpsTest, RunZipTests) { std::ifstream tflite_stream(tflite_test_case); ASSERT_TRUE(tflite_stream.is_open()) << tflite_test_case; - tflite::testing::TfLiteDriver test_driver(/*use_nnapi=*/true); + tflite::testing::TfLiteDriver test_driver(FLAGS_use_nnapi); test_driver.SetModelBaseDir(tflite_dir); string bug_number; @@ -273,7 +278,10 @@ int main(int argc, char** argv) { "Required: Location of the test zip file."), tensorflow::Flag("unzip_binary_path", tflite::testing::FLAGS_unzip_binary_path, - "Required: Location of a suitable unzip binary.")}; + "Required: Location of a suitable unzip binary."), + tensorflow::Flag("use_nnapi", &tflite::testing::FLAGS_use_nnapi, + "Whether to enable the NNAPI delegate")}; + bool success = tensorflow::Flags::Parse(&argc, argv, flags); if (!success || (argc == 2 && !strcmp(argv[1], "--helpfull"))) { fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); @@ -281,6 +289,8 @@ int main(int argc, char** argv) { } ::tflite::LogToStderr(); + // TODO(mikie): googletest arguments do not work - maybe the tensorflow flags + // parser removes them? ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 75ac24719aa8fad960ae06d006eda386d44d721a..f518bf864c6a71679400f0013bbcd40142bb8ca1 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/testing/split.h" namespace tflite { @@ -143,6 +144,7 @@ void TfLiteDriver::AllocateTensors() { Invalidate("Failed to allocate tensors"); return; } + ResetLSTMStateTensors(); must_allocate_tensors_ = false; } } @@ -161,6 +163,7 @@ void TfLiteDriver::LoadModel(const string& bin_file_path) { Invalidate("Failed build interpreter"); return; } + interpreter_->UseNNAPI(use_nnapi_); must_allocate_tensors_ = true; } @@ -281,5 +284,36 @@ bool TfLiteDriver::CheckResults() { return success; } +void TfLiteDriver::ResetLSTMStateTensors() { + // This is a workaround for initializing state tensors for LSTM. + // TODO(ycling): Refactoring and find a better way to initialize state + // tensors. Maybe write the reset instructions into the test data. + for (auto node_index : interpreter_->execution_plan()) { + const auto& node_and_reg = interpreter_->node_and_registration(node_index); + const auto& node = node_and_reg->first; + const auto& registration = node_and_reg->second; + + if (registration.builtin_code == tflite::BuiltinOperator_LSTM) { + const auto* params = + reinterpret_cast(node.builtin_data); + if (params->kernel_type == kTfLiteLSTMFullKernel && + node.outputs->size >= 2) { + // The first 2 outputs of LSTM are state tensors. + for (int i = 0; i < 2; ++i) { + int node_index = node.outputs->data[i]; + ResetTensor(node_index); + } + } else if (params->kernel_type == kTfLiteLSTMBasicKernel && + node.inputs->size == 5) { + // The 2th and 5th inputs are state tensors. + for (int i : {1, 4}) { + int node_index = node.inputs->data[i]; + ResetTensor(node_index); + } + } + } + } +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 02b7de1534e648734d7bc53154afa42f2ef256b4..5493ba3631b0423942cc9c4f98fbd6393a404060 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -48,6 +48,8 @@ class TfLiteDriver : public TestRunner { string ReadOutput(int id) override { return "no-op"; } private: + void ResetLSTMStateTensors(); + class Expectation; bool use_nnapi_ = false; diff --git a/tensorflow/contrib/lite/tflite_static.bp b/tensorflow/contrib/lite/tflite_static.bp index d873ee234e802b62bed08a3bfa8dbca4069d9269..5a78ace359bb99b25f374262b6f689a4ad88cbd1 100644 --- a/tensorflow/contrib/lite/tflite_static.bp +++ b/tensorflow/contrib/lite/tflite_static.bp @@ -24,10 +24,10 @@ cc_library_static { "error_reporter.cc", "graph_info.cc", "interpreter.cc", - "model.cc", + "model.cc", "nnapi_delegate.cc", "optional_debug_tools.cc", - "op_resolver.cc", + "op_resolver.cc", "simple_memory_arena.cc", "string_util.cc", "util.cc", @@ -52,6 +52,7 @@ cc_library_static { "kernels/embedding_lookup.cc", "kernels/embedding_lookup_sparse.cc", "kernels/exp.cc", + "kernels/expand_dims.cc", "kernels/floor.cc", "kernels/fully_connected.cc", "kernels/gather.cc", @@ -71,15 +72,17 @@ cc_library_static { "kernels/register.cc", "kernels/reshape.cc", "kernels/resize_bilinear.cc", - "kernels/select.cc", + "kernels/select.cc", "kernels/skip_gram.cc", "kernels/slice.cc", "kernels/space_to_batch_nd.cc", "kernels/space_to_depth.cc", + "kernels/sparse_to_dense.cc", "kernels/squeeze.cc", "kernels/strided_slice.cc", "kernels/sub.cc", "kernels/svdf.cc", + "kernels/tile.cc", "kernels/transpose.cc", "kernels/transpose_conv.cc", "kernels/unidirectional_sequence_lstm.cc", diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index b8acc9a8e0361a4c38fcbe2f16be172e637b95c6..0789dc99286361183aea4c95db98c11ff700ea79 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -245,6 +245,7 @@ cc_library( "graph_transformations/quantization_util.cc", "graph_transformations/quantization_util.h", "graph_transformations/quantize.cc", + "graph_transformations/quantize_weights.cc", "graph_transformations/read_fake_quant_min_max.cc", "graph_transformations/remove_final_dequantize_op.cc", "graph_transformations/remove_tensorflow_assert.cc", @@ -373,6 +374,7 @@ tf_cc_test( ":toco_tooling", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_googletest//:gtest_main", ], @@ -410,6 +412,7 @@ tf_cc_test( deps = [ ":model", ":tooling_util", + "//tensorflow/core:lib", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 6c0311af0a926711955caaa1c7507d7c52c77069..9f5ca66d050f0ead9b8856c77dba8d9bbd182d10 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -234,6 +234,7 @@ struct ParsedTocoFlags { Arg drop_fake_quant = Arg(false); Arg reorder_across_fake_quant = Arg(false); Arg allow_custom_ops = Arg(false); + Arg quantize_weights = Arg(false); // Deprecated flags Arg input_type; Arg input_types; @@ -242,6 +243,7 @@ struct ParsedTocoFlags { Arg propagate_fake_quant_num_bits = Arg(false); Arg allow_nudging_weights_to_use_fast_gemm_kernel = Arg(false); Arg dedupe_array_min_size_bytes = Arg(64); + Arg split_tflite_lstm_inputs = Arg(true); }; } // namespace toco diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index 6e5927295fe2a9d237683155c9bca90048d478c4..878bda36ef3900d6d8c509aca40cee834cefe514 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include -#include -#include #include #include "absl/strings/str_replace.h" @@ -134,6 +132,12 @@ void AppendArrayVal(string* string, Array const& array, int index) { return; } AppendF(string, "%d", data[index]); + } else if (array.buffer->type == ArrayDataType::kBool) { + const auto& data = array.GetBuffer().data; + if (index >= data.size()) { + return; + } + AppendF(string, "%d", data[index]); } } @@ -142,6 +146,7 @@ NodeProperties GetPropertiesForArray(const Model& model, NodeProperties node_properties; node_properties.color = GetColorForArray(model, array_name); node_properties.label = absl::StrReplaceAll(array_name, {{"/", "/\\n"}}); + node_properties.log2_buffer_size = 0.0f; // Append array shape to the label. auto& array = model.GetArray(array_name); @@ -161,9 +166,12 @@ NodeProperties GetPropertiesForArray(const Model& model, } node_properties.label += "]"; - int buffer_size = RequiredBufferSizeForShape(array.shape()); - node_properties.log2_buffer_size = - std::log2(static_cast(buffer_size)); + int buffer_size = 0; + if (IsValid(array.shape())) { + buffer_size = RequiredBufferSizeForShape(array.shape()); + node_properties.log2_buffer_size = + std::log2(static_cast(buffer_size)); + } if (array.buffer) { const auto& array = model.GetArray(array_name); @@ -196,8 +204,6 @@ NodeProperties GetPropertiesForArray(const Model& model, AppendF(&node_properties.label, "}"); } } - } else { - node_properties.log2_buffer_size = 0.0f; } if (array.minmax) { @@ -304,7 +310,15 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { constexpr char kRNNBackEdgeFormat[] = "\t \"%s\" -> \"%s\" [color=\"#0F9D58\"];\n"; - std::set already_added_arrays; + for (const auto& array_kv : model.GetArrayMap()) { + // Add node for array. + const string& array_name = array_kv.first; + const auto& array_properties = GetPropertiesForArray(model, array_name); + AppendF(output_file_contents, kNodeFormat, array_name, + array_properties.label, "octagon", + array_properties.color.FillColorString().c_str(), + array_properties.color.TextColorString().c_str()); + } for (int op_index = 0; op_index < model.operators.size(); op_index++) { const Operator& op = *model.operators[op_index]; // Add node for operator. @@ -313,20 +327,13 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { AppendF(output_file_contents, kNodeFormat, operator_id, op_properties.label, "box", op_properties.color.FillColorString().c_str(), op_properties.color.TextColorString().c_str()); - // Add nodes and edges for all inputs of the operator. + // Add edges for all inputs of the operator. for (const auto& input : op.inputs) { if (!model.HasArray(input)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } auto array_properties = GetPropertiesForArray(model, input); - if (!already_added_arrays.count(input)) { - AppendF(output_file_contents, kNodeFormat, input, - array_properties.label, "octagon", - array_properties.color.FillColorString().c_str(), - array_properties.color.TextColorString().c_str()); - } - // Draw lines that transport more data thicker (Otherwise, where would the // data fit? right?). float line_width = @@ -342,22 +349,14 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } AppendF(output_file_contents, kEdgeFormat, input, operator_id, line_width, weight); - already_added_arrays.insert(input); } - // Add nodes and edges for all outputs of the operator. + // Add edges for all outputs of the operator. for (const auto& output : op.outputs) { if (!model.HasArray(output)) { // Arrays should _always_ exist. Except, perhaps, during development. continue; } auto array_properties = GetPropertiesForArray(model, output); - if (!already_added_arrays.count(output)) { - AppendF(output_file_contents, kNodeFormat, output, - array_properties.label, "octagon", - array_properties.color.FillColorString().c_str(), - array_properties.color.TextColorString().c_str()); - } - // See comments above regarding weight and line_width calculations. float line_width = std::max(0.5f, array_properties.log2_buffer_size / 3.0f); @@ -367,7 +366,6 @@ void DumpGraphviz(const Model& model, string* output_file_contents) { } AppendF(output_file_contents, kEdgeFormat, operator_id, output, line_width, weight); - already_added_arrays.insert(output); } } diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index f5157149afca17383a8625c489f15a23ce6dd224..c7c80ab21cc88ecb94e5234750431130dbd8400c 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -494,7 +494,7 @@ void ConvertTransposeConvOperator(const Model& model, const auto& weights_array = model.GetArray(weights_array_name); CHECK(weights_array.buffer->type == ArrayDataType::kFloat); ConvertFloatTensorConst(model, weights_array_name, AxesOrder::kOHWI, - AxesOrder::kHWIO, tensorflow_graph); + AxesOrder::kHWOI, tensorflow_graph); auto& strides = (*conv2d_op->mutable_attr())["strides"]; strides.mutable_list()->add_i(1); strides.mutable_list()->add_i(src_op.stride_height); @@ -1728,6 +1728,25 @@ void ConvertComparisonOperator(const Model& model, const Operator& src_op, (*comparison_op->mutable_attr())["T"].set_type(data_type); } +void ConvertSparseToDenseOperator(const Model& model, + const SparseToDenseOperator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + auto* sparse_to_dense_op = tensorflow_graph->add_node(); + sparse_to_dense_op->set_op(op_name); + sparse_to_dense_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (int i = 0; i < 4; ++i) { + *sparse_to_dense_op->add_input() = src_op.inputs[i]; + } + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]); + (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type); + const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b( + src_op.validate_indices); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1919,6 +1938,10 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertRandomUniformOperator( model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowEqual) { + ConvertComparisonOperator(model, src_op, "Equal", tensorflow_graph); + } else if (src_op.type == OperatorType::kTensorFlowNotEqual) { + ConvertComparisonOperator(model, src_op, "NotEqual", tensorflow_graph); } else if (src_op.type == OperatorType::kTensorFlowGreater) { ConvertComparisonOperator(model, src_op, "Greater", tensorflow_graph); } else if (src_op.type == OperatorType::kTensorFlowGreaterEqual) { diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md index 7680cdd344814bf6cbc7bbe11c915f220642d55d..8e93f02ef109f7bccd07ce54baff3d0bb4ae50c7 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md @@ -26,8 +26,6 @@ Table of contents: * [Convert a TensorFlow Lite FlatBuffer back into TensorFlow GraphDef format](#to-graphdef) * [Logging](#logging) - * [Standard logging](#standard-logging) - * [Verbose logging](#verbose-logging) * [Graph "video" logging](#graph-video-logging) * [Graph visualizations](#graph-visualizations) * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot) @@ -277,49 +275,6 @@ bazel run --config=opt \ ## Logging -### Standard logging - -The converter generates some informative log messages during processing. The -easiest way to view them is to add `--logtostderr` to command lines as seen in -the following example. - -``` -curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \ - | tar xzv -C /tmp -bazel run --config=opt \ - //tensorflow/contrib/lite/toco:toco -- \ - --input_file=/tmp/mobilenet_v1_0.50_128/frozen_graph.pb \ - --output_file=/tmp/foo.tflite \ - --input_format=TENSORFLOW_GRAPHDEF \ - --output_format=TFLITE \ - --inference_type=FLOAT \ - --input_shape=1,128,128,3 \ - --input_array=input \ - --output_array=MobilenetV1/Predictions/Reshape_1 \ - --logtostderr -``` - -After some initialization messages, we get the following informative messages: - -``` -I1101 21:51:33.297475 5339 graph_transformations.cc:39] Before general graph transformations: 416 operators, 583 arrays (0 quantized) -I1101 21:51:33.308972 5339 graph_transformations.cc:39] After general graph transformations pass 1: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309204 5339 graph_transformations.cc:39] Before dequantization graph transformations: 31 operators, 89 arrays (0 quantized) -I1101 21:51:33.309368 5339 allocate_transient_arrays.cc:312] Total transient array allocated size: 1048576 bytes, theoretical optimal value: 786432 bytes. -I1101 21:51:33.309484 5339 toco_tooling.cc:249] Estimated count of arithmetic ops: 0.099218 billion (note that a multiply-add is counted as 2 ops). -``` - -### Verbose logging - -For debugging purposes, the converter supports two levels of verbose logging, -which can be set by passing a `--v=` flag: - -* For `--v=1`, the converter generates text dumps of the graph at various - points during processing as well as log messages about every graph - transformation that took place. -* For `--v=2`, the converter additionally generates log messages about graph - transformations that were considered but not performed. - ### Graph "video" logging When `--dump_graphviz=` is used (see the section on [graph diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md index 9e99287f828c22aa81eb216c087f3261e378fc14..8085ae07489816c38677ff792e7ac71f1a75fa71 100644 --- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md +++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md @@ -203,17 +203,11 @@ have. graph transformations on them, at the cost of no longer faithfully matching inference and training arithmetic. -## Logging flags - -The following are standard Google logging flags: +* `--quantize_weights`. Type: boolean. Default: false. Store weights as + quantized weights followed by dequantize operations. Computation is still + done in float, but reduces model size (at the cost of accuracy and latency). -* `--logtostderr` redirects Google logging to standard error, typically making - it visible in a terminal. -* `--v` sets verbose logging levels (for debugging purposes). Defined levels: - * `--v=1`: log all graph transformations that did make a change on the - graph. - * `--v=2`: log all graph transformations that did *not* make a change on - the graph. +## Logging flags The following flags allow to generate graph visualizations of the actual graph at various points during transformations: diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md index f0fd638a618c75c75d336a746f9b1d8dccaea470..a7841a685528fb18bb08f1943278339a2daec16a 100644 --- a/tensorflow/contrib/lite/toco/g3doc/python_api.md +++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md @@ -1,69 +1,202 @@ -# TensorFlow Lite Optimizing Converter (TOCO) Python API reference +# TensorFlow Lite Optimizing Converter & Interpreter Python API reference -This page provides examples on how to use TOCO via the Python API. It is -complemented by the following documents: +This page provides examples on how to use TOCO and the TensorFlow Lite +interpreter via the Python API. It is complemented by the following documents: * [README](../README.md) * [Command-line examples](cmdline_examples.md) * [Command-line glossary](cmdline_reference.md) +Table of contents: + +* [High-level overview](#high-level-overview) +* [API](#api) +* [Basic examples](#basic) + * [Exporting a GraphDef from tf.Session](#basic-graphdef-sess) + * [Exporting a GraphDef from file](#basic-graphdef-file) + * [Exporting a SavedModel](#basic-savedmodel) +* [Complex examples](#complex) + * [Exporting a quantized GraphDef](#complex-quant) +* [TensorFlow Lite Python interpreter](#interpreter) + * [Using the interpreter from a model file](#interpreter-file) + * [Using the interpreter from model data](#interpreter-data) + ## High-level overview While the TensorFlow Lite Optimizing Converter can be used from the command -line, it is often convenient to use it as part of Python model build and +line, it is often convenient to use it as part of a Python model build and training script. This is so that conversion can be part of your model development pipeline. This allows you to know early and often that you are designing a model that can be targeted to devices with mobile. ## API -In Python you can run `help(tf.contrib.lite)` to get documentation on functions. -In particular, `tf.contrib.lite.toco_convert` presents a simple API and -`tf.contrib.lite.toco_from_protos` allows more detailed control of TOCO using -the protobuf interface to TOCO. +The API for converting TensorFlow models to TensorFlow Lite is +`tf.contrib.lite.TocoConverter`. The API for calling the Python intepreter is +`tf.contrib.lite.Interpreter`. + +`TocoConverter` provides class methods based on the original format of the +model. `TocoConverter.from_session()` is available for GraphDefs. +`TocoConverter.from_saved_model()` is available for SavedModels. Example usages +for simple float-point models are shown in [Basic Examples](#basic). Examples +usages for more complex models is shown in [Complex Examples](#complex). + +**NOTE**: Currently, `TocoConverter` will cause a fatal error to the Python +interpreter when the conversion fails. This will be remedied as soon as +possible. + +## Basic examples + +The following section shows examples of how to convert a basic float-point model +from each of the supported data formats into a TensorFlow Lite FlatBuffers. -## Example +### Exporting a GraphDef from tf.Session -In particular, here we show creating a simple model and converting it to a -TensorFlow Lite Model. +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer from a `tf.Session` object. ```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +var = tf.get_variable("weights", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + var out = tf.identity(val, name="out") + with tf.Session() as sess: - tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) - open("test.tflite", "wb").write(tflite_model) + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) ``` -**NOTE** Currently, the TOCO command will cause a fatal error to the Python -interpreter when TOCO conversion fails. This will be remedied as soon as -possible. +### Exporting a GraphDef from file + +The following example shows how to convert a TensorFlow GraphDef into a +TensorFlow Lite FlatBuffer when the GraphDef is stored in a file. Both `.pb` and +`.pbtxt` files are accepted. + +The example uses +[Mobilenet_1.0_224](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz). +The function only supports GraphDefs frozen via +[freeze_graph.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py). + +```python +import tensorflow as tf + +graph_def_file = "/path/to/Downloads/mobilenet_v1_1.0_224/frozen_graph.pb" +input_arrays = ["input"] +output_arrays = ["MobilenetV1/Predictions/Softmax"] + +converter = tf.contrib.lite.TocoConverter.from_frozen_graph( + graph_def_file, input_arrays, output_arrays) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +### Exporting a SavedModel + +The following example shows how to convert a SavedModel into a TensorFlow Lite +FlatBuffer. + +```python +import tensorflow as tf + +converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir) +tflite_model = converter.convert() +open("converted_model.tflite", "wb").write(tflite_model) +``` + +For more complex SavedModels, the optional parameters that can be passed into +`TocoConverter.from_saved_model()` are `input_arrays`, `input_shapes`, +`output_arrays`, `tag_set` and `signature_key`. Details of each parameter are +available by running `help(tf.contrib.lite.TocoConverter)`. -## Example 2: Export with variables +## Complex examples -If a model has variables, they need to be turned into constants. This process is -known as freezing, and it can actually be accomplished with +For models where the default value of the attributes is not sufficient, the +attribute's values should be set before calling `convert()`. In order to call +any constants use `tf.contrib.lite.constants.` as seen below with +`QUANTIZED_UINT8`. Run `help(tf.contrib.lite.TocoConverter)` in the Python +terminal for detailed documentation on the attributes. + +Although the examples are demonstrated on GraphDefs containing only constants. +The same logic can be applied irrespective of the input data format. + +### Exporting a quantized GraphDef + +The following example shows how to convert a quantized model into a TensorFlow +Lite FlatBuffer. ```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) -var = tf.get_variable("weights", dtype=tf.float32, shape=(1,64,64,3)) -val = img + var +const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +val = img + const +out = tf.fake_quant_with_min_max_args(val, min=0., max=1., name="output") + +with tf.Session() as sess: + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + converter.inference_type = tf.contrib.lite.constants.QUANTIZED_UINT8 + input_arrays = converter.get_input_arrays() + converter.quantized_input_stats = {input_arrays[0] : (0., 1.)} # mean, std_dev + tflite_model = converter.convert() + open("converted_model.tflite", "wb").write(tflite_model) +``` + +## TensorFlow Lite Python interpreter + +### Using the interpreter from a model file + +The following example shows how to use the TensorFlow Lite Python interpreter +when provided a TensorFlow Lite FlatBuffer file. The example also demonstrates +how to run inference on random input data. Run +`help(tf.contrib.lite.Interpreter)` in the Python terminal to get detailed +documentation on the interpreter. + +```python +import numpy as np +import tensorflow as tf -def canonical_name(x): - return x.name.split(":")[0] +# Load TFLite model and allocate tensors. +interpreter = tf.contrib.lite.Interpreter(model_path="converted_model.tflite") +interpreter.allocate_tensors() +# Get input and output tensors. +input_details = interpreter.get_input_details() +output_details = interpreter.get_output_details() + +# Test model on random input data. +input_shape = input_details[0]['shape'] +input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32) +interpreter.set_tensor(input_details[0]['index'], input_data) + +interpreter.invoke() +output_data = interpreter.get_tensor(output_details[0]['index']) +print(output_data) +``` + +### Using the interpreter from model data + +The following example shows how to use the TensorFlow Lite Python interpreter +when starting with the TensorFlow Lite Flatbuffer model previously loaded. This +example shows an end-to-end use case, starting from building the TensorFlow +model. + +```python +import numpy as np +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +const = tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +val = img + const out = tf.identity(val, name="out") + with tf.Session() as sess: - sess.run(tf.global_variables_initializer()) - out_tensors = [out] - frozen_graphdef = tf.graph_util.convert_variables_to_constants( - sess, sess.graph_def, map(canonical_name, out_tensors)) - tflite_model = tf.contrib.lite.toco_convert( - frozen_graphdef, [img], out_tensors) - open("converted_model.tflite", "wb").write(tflite_model) + converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out]) + tflite_model = converter.convert() + +# Load TFLite model and allocate tensors. +interpreter = tf.contrib.lite.Interpreter(model_content=tflite_model) +interpreter.allocate_tensors() ``` diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc index 076415ece8c1039caa32e947fe54ab3e101bec9e..8ca2cd66ac6377a70a4c504ac006dc0388b88bf7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -46,8 +46,9 @@ bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) { const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 && - conv_op->stride_height == 1) { - // 1x1 unstrided conv does not need an im2col array. + conv_op->stride_height == 1 && conv_op->dilation_width_factor == 1 && + conv_op->dilation_height_factor == 1) { + // 1x1 unstrided undilated conv does not need an im2col array. return false; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 8da242aa9c2ca4917a681c95c3eded894664c046..1bc7557d46cfa5e1b27468d2da271e75fd491d36 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -139,6 +139,7 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits); DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax) DECLARE_GRAPH_TRANSFORMATION(Quantize) +DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights) DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert) DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index d63ee7c9519d169a2f44ec1afe81125217db8976..bda6dce22be0f0ca83eb8339ad17573b0267c18c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -362,6 +362,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForAverageOrMaxPool(model, op); break; + case OperatorType::kResizeBilinear: + case OperatorType::kSlice: case OperatorType::kStridedSlice: case OperatorType::kSqueeze: case OperatorType::kTensorFlowReshape: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc index ae3301f467de5714230e731b4bab87ddc1637201..d49857cfc22ecaf5feb06b39a42187f8adb61d50 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -90,12 +90,13 @@ bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { } // Conv Op - ConvOperator* conv_op = dynamic_cast( - has_expand_op ? GetOpWithInput(*model, post_stb_op->outputs[0]) - : GetOpWithInput(*model, stb_op->outputs[0])); - if (!conv_op || conv_op->type != OperatorType::kConv) { + const string& input_of_conv_op = + has_expand_op ? post_stb_op->outputs[0] : stb_op->outputs[0]; + auto* conv_base_op = GetOpWithInput(*model, input_of_conv_op); + if (conv_base_op->type != OperatorType::kConv) { return false; } + auto* conv_op = static_cast(conv_base_op); if (conv_op->inputs.size() != 2) { // The conv op must only have weights, no bias. return false; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 3f768bfee12ebe31ebeb72855eb67ec03d5bcf8c..5b6a984ee143a6007471b165510030cd3ad3f73c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -33,9 +33,10 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already a compact LstmCell with LstmCellOperator::NUM_INPUTS of inputs, - // do not need to merge cell inputs. - if (src_op->inputs.size() == LstmCellOperator::NUM_INPUTS) { + // Already a compact LstmCell. Do not need to merge cell inputs. + const auto* src_lstm_op = static_cast(src_op); + if (src_lstm_op->kernel_type != LstmCellOperator::KERNEL_FULL || + src_lstm_op->inputs.size() != kExtendedLstmInputCount) { return false; } @@ -136,6 +137,7 @@ bool MergeLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LSTM cell operator (use basic 5 inputs kernel). auto lstm_cell_op = absl::make_unique(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_BASIC; // Compact LstmCell's 5 inputs. lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index 8e66323bd769ca166d6b521c5b7b2f1cb944b0a2..e6e3dfa1de9c9fdd5e759fd547d11a7b8c95d837 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -33,9 +33,10 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { return false; } - // Already an extended LstmCell with kExtendedLstmInputCount of inputs, - // do not need to split cell inputs. - if (curr_op->inputs.size() == kExtendedLstmInputCount) { + const auto* curr_lstm_op = static_cast(curr_op); + // Already an extended LstmCell. Do not need to split cell inputs. + if (curr_lstm_op->kernel_type != LstmCellOperator::KERNEL_BASIC || + curr_lstm_op->inputs.size() != LstmCellOperator::NUM_INPUTS) { return false; } @@ -56,6 +57,7 @@ bool SplitLstmCellInputs::Run(Model* model, std::size_t op_index) { // Emplace a new LstmCell operator with extended inputs (kernel/lstm.cc). auto lstm_cell_op = absl::make_unique(); + lstm_cell_op->kernel_type = LstmCellOperator::KERNEL_FULL; lstm_cell_op->inputs.resize(kExtendedLstmInputCount); int num_input = model->GetArray(curr_op->inputs[LstmCellOperator::DATA_INPUT]) .shape() diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 6342cf3e8af4d85ad869a5d60a63d62ca2b00588..92d283ca2cc7069f4b80c95ffdadcad81a884cbf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -60,6 +60,8 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowLessEqual: case OperatorType::kTensorFlowGreater: case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kTensorFlowEqual: + case OperatorType::kTensorFlowNotEqual: // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); break; @@ -163,6 +165,16 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type_x); break; } + case OperatorType::kSparseToDense: { + // Select produces outputs with the same type as their 3rd input + CHECK_EQ(op->inputs.size(), 4); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + const ArrayDataType data_type_default = + model->GetArray(op->inputs[3]).data_type; + CHECK(data_type == data_type_default); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 9d1d27f3ef01a572c2ae232b1f172a8e05374381..170a499d4eeea6f38704dadd5274e52da7ae2817 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -278,7 +278,7 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { << "TransposeConv input shape must have 4 dimensions. Input \"" << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape " << toco::ShapeToString(weights_shape) << "."; - CHECK_EQ(input_shape.dims(3), weights_shape.dims(0)) + CHECK_EQ(input_shape.dims(3), weights_shape.dims(3)) << "Input shape depth and weight depth do not agree"; // Set the output shape according to the specified output shape. @@ -1477,6 +1477,34 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { *output_array.mutable_shape()->mutable_dims() = output_dims; } +void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + + const Array& output_shape_array = model->GetArray(op->inputs[1]); + if (!output_shape_array.has_shape()) return; + CHECK_EQ(output_shape_array.shape().dimensions_count(), 1); + + // Output should not go over four dimensions. + CHECK_LE(output_shape_array.shape().dims(0), 4); + + const string& output_name = op->outputs[0]; + Array& output_array = model->GetArray(output_name); + if (output_array.has_shape()) return; + + CHECK(output_shape_array.data_type == ArrayDataType::kInt32 || + output_shape_array.data_type == ArrayDataType::kInt64); + if (output_shape_array.data_type == ArrayDataType::kInt32) { + *output_array.mutable_shape()->mutable_dims() = + output_shape_array.GetBuffer().data; + } else { + const std::vector& output_shape_data = + output_shape_array.GetBuffer().data; + std::copy( + output_shape_data.begin(), output_shape_data.end(), + std::back_inserter(*output_array.mutable_shape()->mutable_dims())); + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1535,6 +1563,8 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowMaximum: case OperatorType::kTensorFlowMinimum: case OperatorType::kTensorFlowGreaterEqual: + case OperatorType::kTensorFlowEqual: + case OperatorType::kTensorFlowNotEqual: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: @@ -1700,6 +1730,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 1); ProcessOpWithShapeInput(model, op); break; + case OperatorType::kSparseToDense: + ProcessSparseToDenseOperator(model, + static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index 142841fcc460e8a5e9e4f2333496f4ece2557275..eca2c701f8bbf889088794c939af7082db0734dd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -45,12 +45,14 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kTensorFlowMinimum || type == OperatorType::kTensorFlowMaximum || type == OperatorType::kLogistic || type == OperatorType::kSoftmax || - type == OperatorType::kLogSoftmax || + type == OperatorType::kLogSoftmax || type == OperatorType::kSlice || + type == OperatorType::kResizeBilinear || type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub || type == OperatorType::kSqueeze || type == OperatorType::kPad || type == OperatorType::kPadV2 || type == OperatorType::kTensorFlowReshape || type == OperatorType::kTanh || type == OperatorType::kMul || + type == OperatorType::kSpaceToBatchND || type == OperatorType::kSpaceToDepth || type == OperatorType::kStridedSlice || type == OperatorType::kDepthToSpace || @@ -60,7 +62,7 @@ bool SupportsQuantization(const Operator& op) { type == OperatorType::kTensorFlowGreaterEqual || type == OperatorType::kTensorFlowLess || type == OperatorType::kTensorFlowLessEqual || - type == OperatorType::kSelect; + type == OperatorType::kSelect || type == OperatorType::kArgMax; } const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc new file mode 100644 index 0000000000000000000000000000000000000000..88ea0945e7dd15ba325d34ea3fdbf34ff7d91381 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +// The minimum number of elements a weights array must have to be quantized +// by this transformation. +// TODO(suharshs): Make this minimum size configurable. +const int kWeightsMinSize = 1024; + +// Gets the quantization params from the float array. +void GetQuantizationParamsFromArray(const Array& array, + QuantizationParams* params) { + const std::vector& float_vals = + array.GetBuffer().data; + auto minmax = std::minmax_element(float_vals.begin(), float_vals.end()); + MinMax toco_minmax; + toco_minmax.min = *minmax.first; + toco_minmax.max = *minmax.second; + GetQuantizationParams(ArrayDataType::kUint8, toco_minmax, params); +} + +} // namespace + +bool QuantizeWeights::Run(Model* model, std::size_t op_index) { + const auto op_it = model->operators.begin() + op_index; + Operator* op = op_it->get(); + + // Get the weights tensor, if the current operator has one. + int weights_index; + if (op->type == OperatorType::kConv || + op->type == OperatorType::kDepthwiseConv || + op->type == OperatorType::kFullyConnected) { + weights_index = 1; + } else if (op->type == OperatorType::kLstmCell) { + weights_index = LstmCellOperator::WEIGHTS_INPUT; + } else { + return false; + } + + // Return early if the array isn't a constant param, this can happen in early + // transformation passes until transpose operations following the weight array + // are resolved. + const string weights = op->inputs[weights_index]; + if (!IsConstantParameterArray(*model, weights)) { + return false; + } + + // Return early if the weight tensor is not type float. + Array& weights_array = model->GetArray(weights); + if (weights_array.data_type != ArrayDataType::kFloat) { + return false; + } + + // Return early if the tensor is too small. Small tensors don't take up too + // much space and can result in bad quantization results. + if (weights_array.GetBuffer().data.size() < + kWeightsMinSize) { + return false; + } + + // Quantize the weight tensor to type kUint8. + QuantizationParams params; + GetQuantizationParamsFromArray(weights_array, ¶ms); + QuantizeArray(this, model, weights, ArrayDataType::kUint8, params); + + // Insert a Dequantize operation after the quantized weights tensor. + auto* dequantize_op = new DequantizeOperator; + model->operators.emplace(op_it, dequantize_op); + + // Create a new intermediate tensor to connect the Dequantize op to the + // original op. + const string dequantized_output = + AvailableArrayName(*model, weights + "_dequantized"); + Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output); + dequantized_output_array.data_type = ArrayDataType::kFloat; + + // Connect up the new Dequantize op with the weights and original op. + op->inputs[weights_index] = dequantized_output; + dequantize_op->inputs = {weights}; + dequantize_op->outputs = {dequantized_output}; + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD index 8dcd4adc90b188c745cadb9815c3c46383705833..95e8433be2a332cfce5175f4f65ea0b83d5638c5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD @@ -8,8 +8,8 @@ load( ) tf_cc_test( - name = "resolve_constant_concatenation_test", - srcs = ["resolve_constant_concatenation_test.cc"], + name = "lstm_utils_test", + srcs = ["lstm_utils_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", @@ -19,8 +19,20 @@ tf_cc_test( ) tf_cc_test( - name = "lstm_utils_test", - srcs = ["lstm_utils_test.cc"], + name = "quantize_weights_test", + srcs = ["quantize_weights_test.cc"], + deps = [ + "//tensorflow/contrib/lite/toco:graph_transformations", + "//tensorflow/contrib/lite/toco:model", + "//tensorflow/contrib/lite/toco:tooling_util", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + ], +) + +tf_cc_test( + name = "resolve_constant_concatenation_test", + srcs = ["resolve_constant_concatenation_test.cc"], deps = [ "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c05eb0929fd775d315fa735b4c9842a7fc024fa8 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include +#include +#include "absl/memory/memory.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +class QuantizeWeightsTest : public ::testing::Test { + protected: + QuantizeWeightsTest() {} + + // The name of the weights input array. + const string kWeightsName = "weights"; + // The zero_point of the values in the input array. + const int kZeroPoint = 128; + + // Prepare a hypothetical TOCO model of a quantizable fully connected float + // layer. + void PrepareModel(Model* model, int elements_per_dim) { + std::vector fc_input_names = {"inputs", kWeightsName}; + + const int kDim = 4; + const int buf_size = std::pow(elements_per_dim, static_cast(kDim)); + auto in_buf = absl::make_unique(buf_size); + // Initialize the array with values from -128.0 to 127.0, since these values + // should be exactly representable by quantization. + for (int i = 0; i < buf_size; i++) { + in_buf[i] = static_cast(i % 256 - kZeroPoint); + } + + for (const string& fc_input_name : fc_input_names) { + Array& in_array = model->GetOrCreateArray(fc_input_name); + in_array.data_type = ArrayDataType::kFloat; + + // Initialize shape for the input array. + Shape* in_array_shape = in_array.mutable_shape(); + std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); + in_array_shape_dim->resize(kDim, elements_per_dim); + auto& in_array_buffer = + in_array.GetMutableBuffer(); + in_array_buffer.data.resize(buf_size); + float* buf_ptr = + in_array.GetMutableBuffer().data.data(); + std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr); + } + + auto* fc_op = new FullyConnectedOperator; + fc_op->inputs = fc_input_names; + fc_op->outputs = {"fc_op_outputs"}; + Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]); + out_array.data_type = ArrayDataType::kFloat; + Shape* out_array_shape = out_array.mutable_shape(); + std::vector* out_array_shape_dim = out_array_shape->mutable_dims(); + out_array_shape_dim->resize(kDim, elements_per_dim); + model->operators.push_back(std::unique_ptr(fc_op)); + } +}; + +TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) { + // Test that weight arrays that are large enough are quantized. + Model model; + // 6 elements per dim gives us 1296 elements, which is sufficient to be + // quantized. + PrepareModel(&model, 6); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (const auto& element : float_array_map) { + EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat); + } + const std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. + const auto& quantized_array_map = model.GetArrayMap(); + EXPECT_EQ(quantized_array_map.size(), 4); + // After the transformation, three arrays should be type float and one array + // should be uint8. + int num_float = 0; + int num_uint8 = 0; + for (const auto& element : quantized_array_map) { + if (element.second->data_type == ArrayDataType::kFloat) { + num_float++; + } else if (element.second->data_type == ArrayDataType::kUint8) { + num_uint8++; + } else { + FAIL() << "Unexpected array type."; + } + } + EXPECT_EQ(num_float, 3); + EXPECT_EQ(num_uint8, 1); + // Ensure that the values were quantized correctly. + const std::vector& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint); + } + + // Ensure that a Dequantize operator has been inserted before the + // FullyConnectedLayer. + EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize); +} + +TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) { + // Test that weight arrays that are too small are left untouched. + Model model; + // 5 elements per dim gives us 625 elements, which is NOT sufficient to be + // quantized. + PrepareModel(&model, 5); + + // Check the state of the graph before the transformation. + const auto& float_array_map = model.GetArrayMap(); + EXPECT_EQ(float_array_map.size(), 3); + // Before the transformation, all arrays should be type float. + for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + std::vector float_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + + // Invoke the transformation. + GraphTransformationsSet graph_transformation_set; + graph_transformation_set.Add(new toco::QuantizeWeights); + (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0); + + // Check the state of the graph after the transformation. + const auto& post_array_map = model.GetArrayMap(); + EXPECT_EQ(post_array_map.size(), 3); + for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) { + EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat); + } + // Ensure that the values remain unchanged. + std::vector const& quantized_weight_vals = + model.GetArray(kWeightsName).GetBuffer().data; + for (int i = 0; i < quantized_weight_vals.size(); i++) { + EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]); + } +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 3a1d175b9823f085c9b8730caba8bedd7eb87d52..66cfed4ac26969729d1881f11ba6ae74d9817fb5 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -12,9 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include @@ -126,7 +124,7 @@ class ResolveConstantConcatenationTest : public ::testing::Test { Array& in_array = model->GetOrCreateArray(concat_input_name); in_array.data_type = ArrayDataType::kFloat; - // Initialize shape for the input array. + // Initialize shape for the input array. Shape* in_array_shape = in_array.mutable_shape(); std::vector* in_array_shape_dim = in_array_shape->mutable_dims(); for (int i = 0; i < kDim; i++) { diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index ea051bb84ac1b70612397b7a929cf9c5d82c59de..cd4f034dfea57b6d379b67a90ba4fa3fe3d615d5 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/tensorflow_graph_matching/resolve_cluster.h" #include "tensorflow/contrib/lite/toco/tensorflow_util.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function.h" @@ -44,6 +43,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/public/version.h" @@ -63,8 +63,6 @@ using tensorflow::TensorShapeProto; namespace toco { -using port::Status; - namespace { bool HasAttr(const NodeDef& node, const string& attr_name) { return node.attr().count(attr_name) > 0; @@ -130,6 +128,42 @@ const AttrValue::ListValue& GetListAttr(const NodeDef& node, return attr.list(); } +tensorflow::Status CheckOptionalAttr(const NodeDef& node, + const string& attr_name, + const string& expected_value) { + if (HasAttr(node, attr_name)) { + const string& value = GetStringAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + expected_value + "'"); + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status CheckOptionalAttr( + const NodeDef& node, const string& attr_name, + const tensorflow::DataType& expected_value) { + if (HasAttr(node, attr_name)) { + const tensorflow::DataType& value = GetDataTypeAttr(node, attr_name); + if (value != expected_value) { + return tensorflow::errors::InvalidArgument( + "Unexpected value for attribute '" + attr_name + "'. Expected '" + + tensorflow::DataType_Name(expected_value) + "'"); + } + } + return tensorflow::Status::OK(); +} + +template +tensorflow::Status ExpectValue(const T1& v1, const T2& v2, + const string& description) { + if (v1 == v2) return tensorflow::Status::OK(); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Unexpected ", description, ": got ", v1, ", expected ", v2)); +} + ArrayDataType ConvertDataType(tensorflow::DataType dtype) { if (dtype == DT_UINT8) return ArrayDataType::kUint8; @@ -148,9 +182,10 @@ ArrayDataType ConvertDataType(tensorflow::DataType dtype) { return ArrayDataType::kNone; } -Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< - tensorflow::TensorShapeProto_Dim>& input_dims, - int* input_flat_size, Shape* shape) { +tensorflow::Status ImportShape( + const TFLITE_PROTO_NS::RepeatedPtrField& + input_dims, + int* input_flat_size, Shape* shape) { std::vector input_dims_only_sizes; for (auto& d : input_dims) { if (d.size() == 0) { @@ -160,23 +195,24 @@ Status ImportShape(const TFLITE_PROTO_NS::RepeatedPtrField< // For now, tweaking this to record a 0-D shape instead. shape->mutable_dims()->clear(); if (input_flat_size != nullptr) *input_flat_size = 0; - return Status::OK(); + return tensorflow::Status::OK(); } // TensorFlow's shapes use int64s, while TOCO uses ints. if (d.size() > std::numeric_limits::max()) { - return Status(false, "Shape element overflows"); + return tensorflow::errors::InvalidArgument("Shape element overflows"); } input_dims_only_sizes.push_back(d.size()); } *shape->mutable_dims() = input_dims_only_sizes; - if (input_flat_size == nullptr) return Status::OK(); + if (input_flat_size == nullptr) return tensorflow::Status::OK(); return NumElements(input_dims_only_sizes, input_flat_size); } -Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportFloatArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_FLOAT); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -203,18 +239,18 @@ Status ImportFloatArray(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_float_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(float), ") nor float_val (", input_tensor.float_val_size(), ") have the right dimensions (", input_flat_size, ") for this float tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportQuint8Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_QUINT8); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -236,18 +272,18 @@ Status ImportQuint8Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(uint8_t), ") nor int_val (", input_tensor.int_val_size(), ") have the right dimensions (", input_flat_size, ") for this uint8 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt32Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT32); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -269,18 +305,17 @@ Status ImportInt32Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, - absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size() / sizeof(int32), - ") nor int_val (", input_tensor.int_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this int32 tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", + input_tensor.tensor_content().size() / sizeof(int32), ") nor int_val (", + input_tensor.int_val_size(), ") have the right dimensions (", + input_flat_size, ") for this int32 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportInt64Array(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_INT64); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -302,18 +337,18 @@ Status ImportInt64Array(const TensorProto& input_tensor, Array* output_array) { toco::port::CopyToBuffer(input_tensor.tensor_content(), reinterpret_cast(output_int_data.data())); } else { - return Status( - false, + return tensorflow::errors::InvalidArgument( absl::StrCat("Neither input_content (", input_tensor.tensor_content().size() / sizeof(int64), ") nor int64_val (", input_tensor.int64_val_size(), ") have the right dimensions (", input_flat_size, ") for this int64 tensor")); } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportBoolArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_BOOL); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -343,19 +378,19 @@ Status ImportBoolArray(const TensorProto& input_tensor, Array* output_array) { // So far only encountered that in an array with 1 entry, let's // require that until we encounter a graph where that's not the case. if (output_bool_data.size() != 1) { - return Status( - false, absl::StrCat("Neither input_content (", - input_tensor.tensor_content().size(), - ") nor bool_val (", input_tensor.bool_val_size(), - ") have the right dimensions (", input_flat_size, - ") for this bool tensor")); + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Neither input_content (", input_tensor.tensor_content().size(), + ") nor bool_val (", input_tensor.bool_val_size(), + ") have the right dimensions (", input_flat_size, + ") for this bool tensor")); } output_bool_data[0] = false; } - return Status::OK(); + return tensorflow::Status::OK(); } -Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { +tensorflow::Status ImportStringArray(const TensorProto& input_tensor, + Array* output_array) { CHECK_EQ(input_tensor.dtype(), DT_STRING); const auto& input_shape = input_tensor.tensor_shape(); CHECK_LE(input_shape.dim_size(), 4); @@ -365,9 +400,9 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { if (!status.ok()) return status; if (input_flat_size != input_tensor.string_val_size()) { - return Status(false, - "Input_content string_val doesn't have the right dimensions " - "for this string tensor"); + return tensorflow::errors::InvalidArgument( + "Input_content string_val doesn't have the right dimensions " + "for this string tensor"); } auto& output_string_data = @@ -377,7 +412,7 @@ Status ImportStringArray(const TensorProto& input_tensor, Array* output_array) { for (int i = 0; i < input_flat_size; ++i) { output_string_data[i] = input_tensor.string_val(i); } - return Status::OK(); + return tensorflow::Status::OK(); } // Count the number of inputs of a given node. If @@ -417,14 +452,14 @@ string CreateConstArray(Model* model, string const& name, return array_name; } -Status ConvertConstOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConstOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Const"); const auto& tensor = GetTensorAttr(node, "value"); const auto dtype = GetDataTypeAttr(node, "dtype"); - Status status = Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); auto& array = model->GetOrCreateArray(node.name()); switch (dtype) { @@ -460,24 +495,21 @@ Status ConvertConstOperator(const NodeDef& node, array.GetMutableBuffer(); break; } - if (!status.ok()) { - status.AppendMessage(" (while processing node '" + node.name() + "')"); - } - return status; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + status, " (while processing node '" + node.name() + "')"); + return tensorflow::Status::OK(); } -void ConvertConvOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ConvertConvOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { CHECK_EQ(node.op(), "Conv2D"); CheckInputsCount(node, tf_import_flags, 2); // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. - if (HasAttr(node, "data_format")) { - CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); - } - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "data_format", "NHWC")); + TF_RETURN_IF_ERROR(CheckOptionalAttr(node, "T", DT_FLOAT)); const auto& input_name = node.input(0); const auto& weights_name = node.input(1); @@ -502,27 +534,26 @@ void ConvertConvOperator(const NodeDef& node, auto* conv = new ConvOperator; conv->inputs = {input_name, reordered_weights_name}; conv->outputs = {node.name()}; + if (!HasAttr(node, "strides")) { + return tensorflow::errors::InvalidArgument("Missing attribute 'strides'"); + } const auto& strides = GetListAttr(node, "strides"); - CHECK_EQ(strides.i_size(), 4); - CHECK_EQ(strides.i(0), 1); - CHECK_EQ(strides.i(3), 1); + TF_RETURN_IF_ERROR(ExpectValue(strides.i_size(), 4, "number of strides")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(0), 1, "strides(0)")); + TF_RETURN_IF_ERROR(ExpectValue(strides.i(3), 1, "strides(3)")); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); if (HasAttr(node, "dilations")) { const auto& dilations = GetListAttr(node, "dilations"); - CHECK_EQ(dilations.i_size(), 4); - CHECK_EQ(dilations.i(0), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; - CHECK_EQ(dilations.i(3), 1) - << "Can only import Conv ops with dilation along the height (1st) or " - "width (2nd) axis. TensorFlow op \"" - << node.name() << "\" had dilations:[ " << dilations.i(0) << ", " - << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3) - << "]."; + TF_RETURN_IF_ERROR( + ExpectValue(dilations.i_size(), 4, "number of dilations")); + if (dilations.i(0) != 1 || dilations.i(3) != 1) { + return tensorflow::errors::InvalidArgument(absl::StrCat( + "Can only import Conv ops with dilation along the height " + "(1st) or width (2nd) axis. TensorFlow op \"", + node.name(), "\" had dilations:[ ", dilations.i(0), ", ", + dilations.i(1), ", ", dilations.i(2), ", ", dilations.i(3), "].")); + } conv->dilation_height_factor = dilations.i(1); conv->dilation_width_factor = dilations.i(2); } else { @@ -535,9 +566,12 @@ void ConvertConvOperator(const NodeDef& node, } else if (padding == "VALID") { conv->padding.type = PaddingType::kValid; } else { - LOG(FATAL) << "Bad padding (only SAME and VALID are supported)"; + return tensorflow::errors::InvalidArgument( + "Bad padding (only SAME and VALID are supported)"); } model->operators.emplace_back(conv); + + return tensorflow::Status::OK(); } void ConvertDepthwiseConvOperator(const NodeDef& node, @@ -614,7 +648,14 @@ void ConvertSpaceToDepthOperator(const NodeDef& node, CHECK_EQ(node.op(), "SpaceToDepth"); CheckInputsCount(node, tf_import_flags, 1); - CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); + tensorflow::DataType dtype = GetDataTypeAttr(node, "T"); + if (dtype != DT_FLOAT && dtype != DT_UINT8 && dtype != DT_INT32 && + dtype != DT_INT64) { + const auto* enum_descriptor = tensorflow::DataType_descriptor(); + LOG(FATAL) << "TFLite does not support SpaceToDepth with type T:" + << enum_descriptor->FindValueByNumber(dtype)->name() << ". " + << "T must be one of {DT_FLOAT, DT_INT8, DT_INT32, DT_INT64}."; + } auto* op = new SpaceToDepthOperator; op->inputs.push_back(node.input(0)); op->outputs.push_back(node.name()); @@ -656,81 +697,6 @@ void ConvertRandomUniform(const NodeDef& node, model->operators.emplace_back(std::move(op)); } -void ConvertReluOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* relu = new ReluOperator; - relu->inputs.push_back(input_name); - relu->outputs.push_back(node.name()); - model->operators.emplace_back(relu); -} - -void ConvertRelu6Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Relu6"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new Relu6Operator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLogOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Log"); - CheckInputsCount(node, tf_import_flags, 1); - - auto op = absl::make_unique(); - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(std::move(op)); -} - -void ConvertLogisticOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sigmoid"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new LogisticOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertTanhOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tanh"); - CheckInputsCount(node, tf_import_flags, 1); - - const auto& input_name = node.input(0); - auto* op = new TanhOperator; - op->inputs.push_back(input_name); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK(node.op() == "Div" || node.op() == "RealDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new DivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertIdentityOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -787,38 +753,6 @@ void ConvertFakeQuantWithMinMaxVars( model->operators.emplace_back(op); } -void ConvertNegOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Neg"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new NegOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertRsqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rsqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowRsqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSqrtOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sqrt"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSqrtOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertSqueezeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -840,66 +774,6 @@ void ConvertSqueezeOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertSquareOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Square"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowSquareOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Add"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new AddOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAddNOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "AddN"); - const int num_inputs = GetInputsCount(node, tf_import_flags); - auto* op = new AddNOperator; - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMulOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Mul"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new MulOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSubOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sub"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new SubOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSumOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -915,67 +789,6 @@ void ConvertSumOperator(const NodeDef& node, } } -void ConvertTileOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Tile"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowTileOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSliceOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Slice"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new SliceOperator; - for (int i = 0; i < 3; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Pad"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new PadOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertPadV2Operator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "PadV2"); - CheckInputsCount(node, tf_import_flags, 3); - auto* op = new PadV2Operator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->inputs.push_back(node.input(2)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertShapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Shape"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new TensorFlowShapeOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSplitOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -993,18 +806,6 @@ void ConvertSplitOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertMergeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Merge"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMergeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertSwitchOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1034,18 +835,6 @@ void ConvertSoftmaxOperator(const NodeDef& node, model->operators.emplace_back(softmax); } -void ConvertLogSoftmaxOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LogSoftmax"); - CheckInputsCount(node, tf_import_flags, 1); - const auto& input_name = node.input(0); - auto* log_softmax = new LogSoftmaxOperator; - log_softmax->inputs.push_back(input_name); - log_softmax->outputs.push_back(node.name()); - model->operators.emplace_back(log_softmax); -} - void ConvertLRNOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1142,17 +931,6 @@ void ConvertAvgPoolOperator(const NodeDef& node, model->operators.emplace_back(avgpool); } -void ConvertReshapeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Reshape"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowReshapeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertBatchMatMulOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1215,24 +993,12 @@ void ConvertConcatOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertAllOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "All"); - auto* op = new TensorFlowAllOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertAssertOperator(const NodeDef& node, +// This method supports simple operators without additional attributes. +template +void ConvertSimpleOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK_EQ(node.op(), "Assert"); - auto* op = new TensorFlowAssertOperator; + auto* op = new Op; const int num_inputs = GetInputsCount(node, tf_import_flags); for (int i = 0; i < num_inputs; ++i) { op->inputs.push_back(node.input(i)); @@ -1241,69 +1007,13 @@ void ConvertAssertOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertLessOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Less"); - auto* op = new TensorFlowLessOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertLessEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "LessEqual"); - auto* op = new TensorFlowLessEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertSinOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Sin"); - auto* op = new SinOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Greater"); - auto* op = new TensorFlowGreaterOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertGreaterEqualOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "GreaterEqual"); - auto* op = new TensorFlowGreaterEqualOperator; - const int num_inputs = GetInputsCount(node, tf_import_flags); - for (int i = 0; i < num_inputs; ++i) { - op->inputs.push_back(node.input(i)); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); +// This method supports simple operators without additional attributes. +template +void ConvertSimpleOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CheckInputsCount(node, tf_import_flags, NumInputs); + ConvertSimpleOperator(node, tf_import_flags, model); } void ConvertMaxOperator(const NodeDef& node, @@ -1336,29 +1046,6 @@ void ConvertMinOperator(const NodeDef& node, } } -void ConvertMaximumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Maximum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMaximumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertMinimumOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Minimum"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TensorFlowMinimumOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertUnsupportedOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1387,19 +1074,6 @@ void ConvertUnsupportedOperator(const NodeDef& node, } } -void ConvertSelectOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CheckInputsCount(node, tf_import_flags, 3); - - auto* op = new SelectOperator; - for (const auto& input : node.input()) { - op->inputs.push_back(input); - } - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStridedSliceOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1678,17 +1352,6 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Exp"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new ExpOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertMeanOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1779,11 +1442,13 @@ void ConvertTransposeConvOperator(const NodeDef& node, if (existing_transpose) { CHECK(existing_transpose->type == OperatorType::kTranspose); } else { - // Transpose weights from HWIO order to OHWI order, which is more efficient - // for computation + // Transpose weights from HWOI order to OHWI order, which is more efficient + // for computation. (Note that TensorFlow considers the order as HWIO + // because they consider this a backward conv, inverting the sense of + // input/output.) TransposeOperator* transpose = new TransposeOperator; string perm_array = CreateConstArray( - model, node.name() + "_transpose_perm", {3, 0, 1, 2}); + model, node.name() + "_transpose_perm", {2, 0, 1, 3}); transpose->inputs = {weights_name, perm_array}; transpose->outputs = {transposed_weights_name}; model->operators.emplace_back(transpose); @@ -1802,53 +1467,6 @@ void ConvertTransposeConvOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertExpandDimsOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "ExpandDims"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new ExpandDimsOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFillOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Fill"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FillOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorDivOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorDiv"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorDivOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - -void ConvertFloorModOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "FloorMod"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new FloorModOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} void ConvertRangeOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, @@ -1869,17 +1487,6 @@ void ConvertRangeOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertRankOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Rank"); - CheckInputsCount(node, tf_import_flags, 1); - auto* op = new RankOperator; - op->inputs.push_back(node.input(0)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} - void ConvertStackOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1900,17 +1507,6 @@ void ConvertStackOperator(const NodeDef& node, model->operators.emplace_back(op); } -void ConvertTransposeOperator(const NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Transpose"); - CheckInputsCount(node, tf_import_flags, 2); - auto* op = new TransposeOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); -} // Some TensorFlow ops only occur in graph cycles, representing // control flow. We do not currently support control flow, so we wouldn't @@ -2133,18 +1729,36 @@ void ConvertDynamicStitchOperator(const NodeDef& node, model->operators.emplace_back(op.release()); } +void ConvertSparseToDenseOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "SparseToDense"); + CheckInputsCount(node, tf_import_flags, 4); + + auto* op = new SparseToDenseOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + + op->validate_indices = HasAttr(node, "validate_indices") + ? GetBoolAttr(node, "validate_indices") + : true; + model->operators.emplace_back(op); +} + } // namespace namespace internal { -Status ImportTensorFlowNode(const tensorflow::NodeDef& node, - const TensorFlowImportFlags& tf_import_flags, - Model* model) { +tensorflow::Status ImportTensorFlowNode( + const tensorflow::NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, Model* model) { // TODO(ahentz): Historically these functions all CHECK-fail on error. We've // been slowly converting them to return Status. if (node.op() == "Const") { return ConvertConstOperator(node, tf_import_flags, model); } else if (node.op() == "Conv2D") { - ConvertConvOperator(node, tf_import_flags, model); + return ConvertConvOperator(node, tf_import_flags, model); } else if (node.op() == "Conv2DBackpropInput") { ConvertTransposeConvOperator(node, tf_import_flags, model); } else if (node.op() == "DepthwiseConv2dNative") { @@ -2156,25 +1770,26 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "BiasAdd") { ConvertBiasAddOperator(node, tf_import_flags, model); } else if (node.op() == "Relu") { - ConvertReluOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Relu6") { - ConvertRelu6Operator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sigmoid") { - ConvertLogisticOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Tanh") { - ConvertTanhOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "MaxPool") { ConvertMaxPoolOperator(node, tf_import_flags, model); } else if (node.op() == "AvgPool") { ConvertAvgPoolOperator(node, tf_import_flags, model); } else if (node.op() == "Reshape") { - ConvertReshapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "BatchMatMul") { ConvertBatchMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "MatMul") { ConvertMatMulOperator(node, tf_import_flags, model); } else if (node.op() == "Div" || node.op() == "RealDiv") { - ConvertDivOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Identity" || node.op() == "CheckNumerics" || node.op() == "StopGradient") { ConvertIdentityOperator(node, tf_import_flags, model); @@ -2183,27 +1798,31 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "FakeQuantWithMinMaxArgs") { ConvertFakeQuantWithMinMaxArgs(node, tf_import_flags, model); } else if (node.op() == "Neg") { - ConvertNegOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Rsqrt") { - ConvertRsqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Squeeze") { ConvertSqueezeOperator(node, tf_import_flags, model); } else if (node.op() == "Sqrt") { - ConvertSqrtOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Square") { - ConvertSquareOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Add") { - ConvertAddOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "AddN") { - ConvertAddNOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Mul") { - ConvertMulOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sub") { - ConvertSubOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Sum") { ConvertSumOperator(node, tf_import_flags, model); } else if (node.op() == "Tile") { - ConvertTileOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Concat" || node.op() == "ConcatV2") { ConvertConcatOperator(node, tf_import_flags, model); } else if (node.op() == "LRN") { @@ -2211,41 +1830,50 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "Softmax") { ConvertSoftmaxOperator(node, tf_import_flags, model); } else if (node.op() == "Log") { - ConvertLogOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "LogSoftmax") { - ConvertLogSoftmaxOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "All") { - ConvertAllOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Assert") { - ConvertAssertOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Less") { - ConvertLessOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "LessEqual") { - ConvertLessEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Greater") { - ConvertGreaterOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "GreaterEqual") { - ConvertGreaterEqualOperator(node, tf_import_flags, model); + ConvertSimpleOperator( + node, tf_import_flags, model); } else if (node.op() == "Max") { ConvertMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Min") { ConvertMinOperator(node, tf_import_flags, model); } else if (node.op() == "Maximum") { - ConvertMaximumOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Minimum") { - ConvertMinimumOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Merge") { - ConvertMergeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Pad") { - ConvertPadOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "PadV2") { - ConvertPadV2Operator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "StridedSlice") { ConvertStridedSliceOperator(node, tf_import_flags, model); } else if (node.op() == "Shape") { - ConvertShapeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, + model); } else if (node.op() == "Slice") { - ConvertSliceOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Split") { ConvertSplitOperator(node, tf_import_flags, model); } else if (node.op() == "Switch") { @@ -2282,25 +1910,25 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "NextIteration") { ConvertOperatorSpecialCasedAsRNNBackEdge(node, tf_import_flags, model); } else if (node.op() == "ExpandDims") { - ConvertExpandDimsOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Fill") { - ConvertFillOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "FloorDiv") { - ConvertFloorDivOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "FloorMod") { - ConvertFloorModOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Range") { ConvertRangeOperator(node, tf_import_flags, model); } else if (node.op() == "Rank") { - ConvertRankOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Stack" || node.op() == "Pack") { ConvertStackOperator(node, tf_import_flags, model); } else if (node.op() == "Transpose") { - ConvertTransposeOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "ArgMax") { ConvertArgMaxOperator(node, tf_import_flags, model); } else if (node.op() == "Exp") { - ConvertExpOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "TopK" || node.op() == "TopKV2") { ConvertTopKV2Operator(node, tf_import_flags, model); } else if (node.op() == "DynamicPartition") { @@ -2311,13 +1939,23 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, } else if (node.op() == "RandomUniform") { ConvertRandomUniform(node, tf_import_flags, model); } else if (node.op() == "Sin") { - ConvertSinOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); + } else if (node.op() == "Log") { + ConvertSimpleOperator(node, tf_import_flags, model); } else if (node.op() == "Select") { - ConvertSelectOperator(node, tf_import_flags, model); + ConvertSimpleOperator(node, tf_import_flags, model); + } else if (node.op() == "SparseToDense") { + ConvertSparseToDenseOperator(node, tf_import_flags, model); + } else if (node.op() == "Equal") { + ConvertSimpleOperator(node, tf_import_flags, + model); + } else if (node.op() == "NotEqual") { + ConvertSimpleOperator(node, tf_import_flags, + model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } - return Status::OK(); + return tensorflow::Status::OK(); } } // namespace internal diff --git a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc index 835676662b9cb7ed20e578e2a35747a64ba443dc..d18c329a43411236f8fd5446998c168803b9373a 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow_test.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { -using port::Status; using tensorflow::AttrValue; using tensorflow::DT_BOOL; using tensorflow::DT_FLOAT; @@ -33,6 +33,7 @@ using tensorflow::DT_INT64; using tensorflow::DT_QUINT8; using tensorflow::DT_STRING; using tensorflow::NodeDef; +using tensorflow::Status; namespace internal { Status ImportTensorFlowNode(const NodeDef&, const TensorFlowImportFlags&, @@ -117,9 +118,10 @@ TEST_P(ShapeImportTest, ShapeElementIsNegative) { NodeDef node; BuildConstNode({1, -2, 10}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_EQ(status.error_message(), - "Tensor shape should not include negative values (while processing " - "node 'Node1')"); + EXPECT_EQ( + status.error_message(), + "Tensor shape should not include negative values\n\t (while processing " + "node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementIsNegative, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -129,7 +131,7 @@ TEST_P(ShapeImportTest, ShapeElementTooLarge) { BuildConstNode({3000000000}, GetParam(), 0, &node); auto status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Shape element overflows (while processing node 'Node1')"); + "Shape element overflows\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeElementTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -139,7 +141,7 @@ TEST_P(ShapeImportTest, ShapeTooLarge) { BuildConstNode({1000000, 2000000, 2000000, 2000000}, GetParam(), 0, &node); auto status = ImportNode(node); EXPECT_EQ(status.error_message(), - "Tensor shape is too large (while processing node 'Node1')"); + "Tensor shape is too large\n\t (while processing node 'Node1')"); } INSTANTIATE_TEST_CASE_P(ShapeTooLarge, ShapeImportTest, ::testing::ValuesIn(TestTypes())); @@ -148,11 +150,11 @@ TEST_P(ShapeImportTest, ValidShapeButZeroElements) { NodeDef node; BuildConstNode({1, 2, 2, 2}, GetParam(), 0, &node); auto status = ImportNode(node); - EXPECT_THAT( - status.error_message(), - ::testing::MatchesRegex( - "Neither input_content .0. nor .*_val .0. have the right " - "dimensions .8. for this .* tensor .while processing node 'Node1'.")); + EXPECT_THAT(status.error_message(), + ::testing::MatchesRegex( + "Neither input_content .0. nor .*_val .0. have the right " + "dimensions .8. for this .* tensor\n\t .while processing " + "node 'Node1'.")); } INSTANTIATE_TEST_CASE_P(ValidShapeButZeroElements, ShapeImportTest, ::testing::ValuesIn(TestTypes())); diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d878ac54e4d819efc1b0951acbbab23b3387eac5..2f43adb07b1c9dc9645942ce6ec868595704baa5 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -135,6 +135,9 @@ enum class OperatorType { // special nodes in the graph to shuffle axes. kReorderAxes, kSelect, + kSparseToDense, + kTensorFlowEqual, + kTensorFlowNotEqual, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -152,6 +155,7 @@ enum class AxesOrder { k1HWO, // Our standard for DepthwiseConv weights kHWIM, // TensorFlow DepthwiseConv weights kNHWC, // TensorFlow activations + kHWOI, // TensorFlow back-prop conv weights }; // The type of the scalars in an array. @@ -526,7 +530,15 @@ struct LstmCellOperator : Operator { ACTIV_TEMP = 3, NUM_OUTPUTS = 4 }; - LstmCellOperator() : Operator(OperatorType::kLstmCell) {} + enum KernelType { + KERNEL_BASIC = 0, + KERNEL_FULL = 1, + }; + + LstmCellOperator() + : Operator(OperatorType::kLstmCell), kernel_type(KERNEL_BASIC) {} + + KernelType kernel_type; }; // Element-wise multiplication operator. @@ -1349,6 +1361,22 @@ struct TensorFlowGreaterEqualOperator : Operator { : Operator(OperatorType::kTensorFlowGreaterEqual) {} }; +// TensorFlow Equal equivalent. Refer to TensorFlow documentation for +// details. +// Not fully supported, just a placeholder to handle TensorFlow graphs and +// support graph transformations to other operator types by matching sub-graphs. +// Typically, this is only used as an input to an Assert node, so can be +// removed as an unused node as we drop Assert nodes. +struct TensorFlowEqualOperator : Operator { + TensorFlowEqualOperator() : Operator(OperatorType::kTensorFlowEqual) {} +}; + +// TensorFlow Not Equal equivalent. Refer to TensorFlow documentation for +// details. +struct TensorFlowNotEqualOperator : Operator { + TensorFlowNotEqualOperator() : Operator(OperatorType::kTensorFlowNotEqual) {} +}; + // Global max reduction: computes the max of all of entries in the input array. // Thus the output is "0-dimensional": it consists of a single scalar value. // @@ -1598,13 +1626,26 @@ struct DynamicStitchOperator : Operator { int num_partitions; }; +// SparseToDense operator: +// +// Inputs: +// Inputs[0]: required: sparse_indices. +// Inputs[1]: required: output_shape. +// Inputs[2]: required: sparse_values. +// +// TensorFlow equivalent: SparseToDense. +struct SparseToDenseOperator : Operator { + SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {} + bool validate_indices; +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are // offsets from the start of the workspace buffer, expressed in bytes. struct Alloc { - int start = 0; - int end = 0; + int64 start = 0; + int64 end = 0; }; inline bool operator<(const Alloc& a, const Alloc& b) { diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index f875c85d1a7447432ea0f3a7d68b028e52cb78d4..4c9f1aa4b0274b5123bb3baa9b9fca1463bda4c3 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -48,7 +48,7 @@ bool ParseModelFlagsFromCommandLineFlags( "that information from the input file."), Flag("input_arrays", parsed_flags.input_arrays.bind(), parsed_flags.input_arrays.default_value(), - "Names of the output arrays, comma-separated. If not specified, " + "Names of the input arrays, comma-separated. If not specified, " "will try to read that information from the input file."), Flag("output_array", parsed_flags.output_array.bind(), parsed_flags.output_array.default_value(), @@ -83,7 +83,7 @@ bool ParseModelFlagsFromCommandLineFlags( "Deprecated: use --input_data_types instead. Input array type, if " "not already provided in the graph. " "Typically needs to be specified when passing arbitrary arrays " - "to --input_array."), + "to --input_arrays."), Flag("input_data_types", parsed_flags.input_data_types.bind(), parsed_flags.input_data_types.default_value(), "Input arrays types, comma-separated, if not already provided in " diff --git a/tensorflow/contrib/lite/toco/python/BUILD b/tensorflow/contrib/lite/toco/python/BUILD index 6c4f8e12cdd5b3222997c4a2b0ac243cc74324e0..93fe756a55d378fa205ff88be5e18aff586e5dca 100644 --- a/tensorflow/contrib/lite/toco/python/BUILD +++ b/tensorflow/contrib/lite/toco/python/BUILD @@ -12,10 +12,11 @@ cc_library( deps = [ "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", + "//tensorflow/contrib/lite/toco:toco_graphviz_dump_options", "//tensorflow/contrib/lite/toco:toco_port", "//tensorflow/contrib/lite/toco:toco_tooling", "//tensorflow/core:lib", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -26,7 +27,7 @@ tf_py_wrap_cc( ":toco_python_api", "//tensorflow/contrib/lite/toco:model_flags_proto_cc", "//tensorflow/contrib/lite/toco:toco_flags_proto_cc", - "//util/python:python_headers", + "//third_party/python_runtime:headers", "@com_google_absl//absl/strings", ], ) @@ -41,12 +42,6 @@ py_binary( ], ) -py_binary( - name = "toco_wrapper", - srcs = ["toco_wrapper.py"], - srcs_version = "PY2AND3", -) - tf_py_test( name = "toco_from_protos_test", srcs = ["toco_from_protos_test.py"], diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.cc b/tensorflow/contrib/lite/toco/python/toco_python_api.cc index 5b1db852b4f8e89c1a591cfe18a0ab0aa2db04c9..d93e104038741e6e59608f04115854d611f1f9ae 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.cc +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/python/toco_python_api.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" +#include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_tooling.h" #include "tensorflow/contrib/lite/toco/toco_types.h" @@ -62,7 +63,7 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, std::string input_contents_txt = ConvertArg(input_contents_txt_raw, &error); if (error) return nullptr; - // Use toco to produce new outputs + // Use TOCO to produce new outputs. toco::ModelFlags model_flags; if (!model_flags.ParseFromString(model_flags_proto_txt)) { LOG(FATAL) << "Model proto failed to parse." << std::endl; @@ -71,6 +72,16 @@ PyObject* TocoConvert(PyObject* model_flags_proto_txt_raw, if (!toco_flags.ParseFromString(toco_flags_proto_txt)) { LOG(FATAL) << "Toco proto failed to parse." << std::endl; } + + auto& dump_options = *GraphVizDumpOptions::singleton(); + if (toco_flags.has_dump_graphviz_dir()) { + dump_options.dump_graphviz = toco_flags.dump_graphviz_dir(); + } + if (toco_flags.has_dump_graphviz_include_video()) { + dump_options.dump_graphviz_video = toco_flags.dump_graphviz_include_video(); + } + + // Convert model. std::unique_ptr model = toco::Import(toco_flags, model_flags, input_contents_txt); toco::Transform(toco_flags, model.get()); diff --git a/tensorflow/contrib/lite/toco/python/toco_python_api.h b/tensorflow/contrib/lite/toco/python/toco_python_api.h index 9af38e937c29804f950ea53ee86b70e3ccb02360..7e8ad9c1dafa68dd91e4a0eb3bfb742207878c59 100644 --- a/tensorflow/contrib/lite/toco/python/toco_python_api.h +++ b/tensorflow/contrib/lite/toco/python/toco_python_api.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ #define _THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_PYTHON_TOCO_PYTHON_API_H_ -#include #include +#include namespace toco { diff --git a/tensorflow/contrib/lite/toco/python/toco_wrapper.py b/tensorflow/contrib/lite/toco/python/toco_wrapper.py deleted file mode 100644 index 6d6b500d7eccd353f566a4bad76df35e0e849d95..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/toco/python/toco_wrapper.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Wrapper for runninmg toco binary embedded in pip site-package. - -NOTE: this mainly exists since PIP setup.py cannot install binaries to bin/. -It can only install Python "console-scripts." This will work as a console -script. See tools/pip_package/setup.py (search for CONSOLE_SCRIPTS). -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import sys - - -def main(): - # Pip installs the binary in aux-bin off of main site-package install. - # Just find it and exec, passing all arguments in the process. - # TODO(aselle): it is unfortunate to use all of tensorflow to lookup binary. - print("""TOCO from pip install is currently not working on command line. -Please use the python TOCO API or use -bazel run tensorflow/contrib/lite:toco -- from a TensorFlow source dir. -""") - sys.exit(1) - # TODO(aselle): Replace this when we find a way to run toco without - # blowing up executable size. - # binary = os.path.join(tf.__path__[0], 'aux-bin/toco') - # os.execvp(binary, sys.argv) diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc index 5daa703c80b3b5d9152c5d21976260f21679a3f2..a2d753657b0bf6c88f5c94a20a1240fb7c13a37c 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.cc +++ b/tensorflow/contrib/lite/toco/tflite/export.cc @@ -316,6 +316,7 @@ void Export( auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map, &builder, &error_summary); const string fake_quant_operation_name = "FAKE_QUANT"; + if (error_summary.count(fake_quant_operation_name) != 0) { LOG(ERROR) << fake_quant_operation_name @@ -327,6 +328,21 @@ void Export( error_summary.erase(fake_quant_operation_name); } if (!allow_custom_ops && !error_summary.empty()) { + // Remove ExpandDims and ReorderAxes from unimplemented list unless they + // compose the list. Both ops are removed during graph transformations. + // However, if an op is unimplemented earlier in the model, the graph + // transformation is unable to run because the output shape is not defined. + // This causes unnecessary confusion during model conversion time. + std::set error_summary_final; + for (const auto& op_type : error_summary) { + if (op_type != "ReorderAxes" && op_type != "ExpandDims") { + error_summary_final.insert(op_type); + } + } + if (error_summary_final.empty()) { + error_summary_final = error_summary; + } + LOG(QFATAL) << "Some of the operators in the model are not supported by " "the standard TensorFlow Lite runtime. If you have a custom " @@ -334,7 +350,7 @@ void Export( "--allow_custom_ops, or by setting allow_custom_ops=True " "when calling tf.contrib.lite.toco_convert(). Here is a list " "of operators for which you will need custom implementations: " - << absl::StrJoin(error_summary, ", ") << "."; + << absl::StrJoin(error_summary_final, ", ") << "."; } auto ops = diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h index 90abfb94d8d091525cc6ce7b12e2e29c7e648160..098d2163e6c2fe26f3cb9cdf9959df62a1a4baf0 100644 --- a/tensorflow/contrib/lite/toco/tflite/export.h +++ b/tensorflow/contrib/lite/toco/tflite/export.h @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tflite/operator.h" +#include "tensorflow/contrib/lite/util.h" namespace toco { @@ -72,22 +73,10 @@ struct OperatorKey { struct Hash { size_t operator()(const OperatorKey& key) const { - return CombineHashes({std::hash()(static_cast(key.type)), - std::hash()(key.custom_code), - std::hash()(key.version)}); - } - - private: - // TODO(ycling): Refactoring and extract this function into a common - // utility module. - static size_t CombineHashes(std::initializer_list hashes) { - size_t result = 0; - // Hash combiner used by TensorFlow core. - for (size_t hash : hashes) { - result = result ^ (hash + 0x9e3779b97f4a7800ULL + (result << 10) + - (result >> 4)); - } - return result; + return ::tflite::CombineHashes( + {std::hash()(static_cast(key.type)), + std::hash()(key.custom_code), + std::hash()(key.version)}); } }; }; diff --git a/tensorflow/contrib/lite/toco/tflite/import.cc b/tensorflow/contrib/lite/toco/tflite/import.cc index c0e7ab2ef57ed8edf1b7cda08c64f6ae66172af3..cb44a5e6d7356a1cf5597bbe48565c5b1e1949a6 100644 --- a/tensorflow/contrib/lite/toco/tflite/import.cc +++ b/tensorflow/contrib/lite/toco/tflite/import.cc @@ -113,15 +113,35 @@ void ImportOperators( << operators_table.size(); } string opname = operators_table.at(index); + + // Find and use the appropriate operator deserialization factory. + std::unique_ptr new_op = nullptr; if (ops_by_name.count(opname) == 0) { - LOG(FATAL) << "Op '" << opname << "' not supported"; + string effective_opname = "TENSORFLOW_UNSUPPORTED"; + if (ops_by_name.count(effective_opname) == 0) { + LOG(FATAL) << "Internal logic error: TENSORFLOW_UNSUPPORTED not found."; + } + new_op = ops_by_name.at(effective_opname) + ->Deserialize(input_op->builtin_options(), + input_op->custom_options()); + if (new_op->type == OperatorType::kTensorFlowUnsupported) { + auto* unsupported_op = + static_cast(new_op.get()); + unsupported_op->tensorflow_op = opname; + // TODO(b/109932940): Remove this when quantized is removed. + // For now, we assume all ops are quantized. + unsupported_op->quantized = true; + } else { + LOG(FATAL) << "Expected a TensorFlowUnsupportedOperator"; + } + } else { + new_op = ops_by_name.at(opname)->Deserialize(input_op->builtin_options(), + input_op->custom_options()); } - - auto new_op = ops_by_name.at(opname)->Deserialize( - input_op->builtin_options(), input_op->custom_options()); model->operators.emplace_back(new_op.release()); auto* op = model->operators.back().get(); + // Make sure all the inputs and outputs are hooked up. auto inputs = input_op->inputs(); for (int i = 0; i < inputs->Length(); i++) { auto input_index = inputs->Get(i); diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 6922e5055a602b8d2eb43f88cde15b0d505eac40..7490ab960b9b0c62bef4c343927664ac6ae4eb9d 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -507,6 +507,22 @@ class Pad : public BuiltinOperator { + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateTileOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + int GetVersion(const Operator& op) const override { return 1; } +}; + class PadV2 : public BuiltinOperator { public: @@ -610,11 +626,21 @@ class Lstm : public BuiltinOperator WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { + ::tflite::LSTMKernelType kernel_type; + switch (op.kernel_type) { + case LstmCellOperator::KERNEL_BASIC: + kernel_type = ::tflite::LSTMKernelType_BASIC; + break; + case LstmCellOperator::KERNEL_FULL: + kernel_type = ::tflite::LSTMKernelType_FULL; + break; + } + // Current toco converter only supports tanh, no clip. return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ ::tflite::ActivationFunctionType_TANH, /*cell_clip=*/0.0, - /*proj_clip=*/0.0); + /*proj_clip=*/0.0, kernel_type); } void ReadOptions(const TfLiteOptions& options, @@ -622,9 +648,26 @@ class Lstm : public BuiltinOperatorkernel_type = LstmCellOperator::KERNEL_BASIC; + break; + case ::tflite::LSTMKernelType_FULL: + op->kernel_type = LstmCellOperator::KERNEL_FULL; + break; + } } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& lstm_op = static_cast(op); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: + return 1; + case LstmCellOperator::KERNEL_BASIC: + return 2; + } + } }; class Mean : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->validate_indices = options.validate_indices(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ExpandDims + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateExpandDimsOptions(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -976,8 +1058,14 @@ std::vector> BuildOperatorList() { new Cast(::tflite::BuiltinOperator_CAST, OperatorType::kCast)); ops.emplace_back( new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); + ops.emplace_back( + new Tile(::tflite::BuiltinOperator_TILE, OperatorType::kTensorFlowTile)); + ops.emplace_back(new ExpandDims(::tflite::BuiltinOperator_EXPAND_DIMS, + OperatorType::kExpandDims)); ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); + ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, + OperatorType::kSparseToDense)); // Custom Operators. ops.emplace_back( @@ -1024,12 +1112,18 @@ std::vector> BuildOperatorList() { "LESS", OperatorType::kTensorFlowLess)); ops.emplace_back(new SimpleOperator( "LESS_EQUAL", OperatorType::kTensorFlowLessEqual)); + ops.emplace_back(new SimpleOperator( + "EQUAL", OperatorType::kTensorFlowEqual)); + ops.emplace_back(new SimpleOperator( + "NOT_EQUAL", OperatorType::kTensorFlowNotEqual)); ops.emplace_back(new SimpleOperator("NEG", OperatorType::kNeg)); ops.emplace_back( new SimpleOperator("SELECT", OperatorType::kSelect)); ops.emplace_back( new SimpleOperator("SLICE", OperatorType::kSlice)); + // Element-wise operator ops.emplace_back(new SimpleOperator("SIN", OperatorType::kSin)); + ops.emplace_back(new SimpleOperator("LOG", OperatorType::kLog)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index fe594c6da9826ab904d162c9e28e1455b1bf69f6..03bb20b3208196e964d950c0f0954d1fc0ba9e86 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -74,8 +74,10 @@ class OperatorTest : public ::testing::Test { auto new_toco_op = op.Deserialize(output_options->builtin_options(), output_options->custom_options()); - CHECK(dynamic_cast(new_toco_op.get())) - << "Cannot cast " << HelpfulOperatorTypeName(*new_toco_op) << " to " + CHECK(new_toco_op->type == toco_op.type) + << "The type of the serialized and deserialized" + << HelpfulOperatorTypeName(*new_toco_op) + << " does not match the type of the original " << HelpfulOperatorTypeName(toco_op); return std::unique_ptr(dynamic_cast(new_toco_op.release())); @@ -119,6 +121,11 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("SELECT", OperatorType::kSelect); CheckSimpleOperator("SLICE", OperatorType::kSlice); CheckSimpleOperator("SIN", OperatorType::kSin); + CheckSimpleOperator("EQUAL", + OperatorType::kTensorFlowEqual); + CheckSimpleOperator( + "NOT_EQUAL", OperatorType::kTensorFlowNotEqual); + CheckSimpleOperator("LOG", OperatorType::kLog); } TEST_F(OperatorTest, BuiltinAdd) { @@ -420,6 +427,15 @@ TEST_F(OperatorTest, BuiltinTransposeConv) { EXPECT_EQ(op.padding.type, output_toco_op->padding.type); } +TEST_F(OperatorTest, BuiltinSparseToDense) { + SparseToDenseOperator op; + op.validate_indices = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize( + GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op); + EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index 7786a4ada335abc9a01a0a6e423125f2d67957c2..87a1e429b928bf59cb14597980602953732a7659 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -153,6 +153,16 @@ bool ParseTocoFlagsFromCommandLineFlags( parsed_flags.dedupe_array_min_size_bytes.default_value(), "Minimum size of constant arrays to deduplicate; arrays smaller " "will not be deduplicated."), + Flag("split_tflite_lstm_inputs", + parsed_flags.split_tflite_lstm_inputs.bind(), + parsed_flags.split_tflite_lstm_inputs.default_value(), + "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. " + "Ignored if the output format is not TFLite."), + Flag("quantize_weights", parsed_flags.quantize_weights.bind(), + parsed_flags.quantize_weights.default_value(), + "Store weights as quantized weights followed by dequantize " + "operations. Computation is still done in float, but reduces model " + "size (at the cost of accuracy and latency)."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); @@ -245,6 +255,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel, FlagRequirement::kNone); READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone); + READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone); + READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone); // Deprecated flag handling. if (parsed_toco_flags.input_type.specified()) { @@ -278,6 +290,11 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, QCHECK(toco::IODataType_Parse(input_types[0], &input_type)); toco_flags->set_inference_input_type(input_type); } + if (parsed_toco_flags.quantize_weights.value()) { + QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8) + << "quantize_weights is not supported with inference_type " + "QUANTIZED_UINT8."; + } #undef READ_TOCO_FLAG #undef PARSE_TOCO_FLAG diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 8589ca361dae2561207f9fa0c57b3240240c08d6..ad4e94ded9f9730842a257e065d9aec2b1cbfac8 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 19. +// Next ID to use: 21. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -165,4 +165,22 @@ message TocoFlags { // Minimum size of constant arrays to deduplicate; arrays smaller will not be // deduplicated. optional int64 dedupe_array_min_size_bytes = 18 [default = 64]; + + // Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. + // Ignored if the output format is not TFLite. + optional bool split_tflite_lstm_inputs = 19 [default = true]; + + // Store weights as quantized weights followed by dequantize operations. + // Computation is still done in float, but reduces model size (at the cost of + // accuracy and latency). + optional bool quantize_weights = 20 [default = false]; + + // Full filepath of folder to dump the graphs at various stages of processing + // GraphViz .dot files. Preferred over --output_format=GRAPHVIZ_DOT in order + // to keep the requirements of the output file. + optional string dump_graphviz_dir = 24; + + // Boolean indicating whether to dump the graph after every graph + // transformation. + optional bool dump_graphviz_include_video = 25; } diff --git a/tensorflow/contrib/lite/toco/toco_port.cc b/tensorflow/contrib/lite/toco/toco_port.cc index a1c8696cd06a30bfe8661bb70aa4f2d6d175aac3..de76fd4032d24eff8a6c2fd0c16a911b9c00186b 100644 --- a/tensorflow/contrib/lite/toco/toco_port.cc +++ b/tensorflow/contrib/lite/toco/toco_port.cc @@ -16,8 +16,16 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/toco_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#if defined(__ANDROID__) && defined(__ARM_ARCH_7A__) +namespace std { +double round(double x) { return ::round(x); } +} // namespace std +#endif + namespace toco { namespace port { void CopyToBuffer(const string& src, char* dest) { @@ -55,8 +63,12 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { // Conversion to our wrapper Status. -Status ToStatus(const ::util::Status& uts) { - return Status(uts.ok(), uts.error_message()); +tensorflow::Status ToStatus(const ::util::Status& uts) { + if (!uts.ok()) { + return tensorflow::Status(tensorflow::errors::Code(uts.error_code()), + uts.error_message()); + } + return tensorflow::Status::OK(); } // Conversion to our wrapper Options. @@ -65,7 +77,7 @@ toco::port::file::Options ToOptions(const ::file::Options& options) { return Options(); } -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { File* f = nullptr; const auto status = ::file::Open(filename, "w", &f, ::file::Defaults()); if (f) { @@ -74,22 +86,24 @@ Status Writable(const string& filename) { return ToStatus(status); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { return ToStatus(::file::Readable(filename, ::file::Defaults())); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { auto status = ::file::Exists(filename, ::file::Defaults()); return ToStatus(status); } -Status GetContents(const string& filename, string* contents, - const file::Options& options) { +tensorflow::Status GetContents(const string& filename, string* contents, + const file::Options& options) { return ToStatus(::file::GetContents(filename, contents, ::file::Defaults())); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { return ToStatus(::file::SetContents(filename, contents, ::file::Defaults())); } @@ -133,37 +147,42 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { -Status Writable(const string& filename) { +tensorflow::Status Writable(const string& filename) { FILE* f = fopen(filename.c_str(), "w"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not writable"); + return tensorflow::errors::NotFound("not writable"); } -Status Readable(const string& filename, const file::Options& options) { +tensorflow::Status Readable(const string& filename, + const file::Options& options) { FILE* f = fopen(filename.c_str(), "r"); if (f) { fclose(f); - return Status(true, ""); + return tensorflow::Status::OK(); } - return Status(false, "not readable"); + return tensorflow::errors::NotFound("not readable"); } -Status Exists(const string& filename, const file::Options& options) { +tensorflow::Status Exists(const string& filename, + const file::Options& options) { struct stat statbuf; int ret = stat(filename.c_str(), &statbuf); - return Status(ret != -1, ""); + if (ret == -1) { + return tensorflow::errors::NotFound("file doesn't exist"); + } + return tensorflow::Status::OK(); } -Status GetContents(const string& path, string* output, - const file::Options& options) { +tensorflow::Status GetContents(const string& path, string* output, + const file::Options& options) { output->clear(); int fd = open(path.c_str(), O_RDONLY); if (fd == -1) { - return Status(false, "can't open() for read"); + return tensorflow::errors::NotFound("can't open() for read"); } // Direct read, for speed. @@ -174,25 +193,25 @@ Status GetContents(const string& path, string* output, if (size == 0) { // Done. close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } else if (size == -1) { // Error. close(fd); - return Status(false, "error during read()"); + return tensorflow::errors::Internal("error during read()"); } else { output->append(buffer, size); } } CHECK(0); - return Status(false, "internal error"); + return tensorflow::errors::Internal("internal error"); } -Status SetContents(const string& filename, const string& contents, - const file::Options& options) { +tensorflow::Status SetContents(const string& filename, const string& contents, + const file::Options& options) { int fd = open(filename.c_str(), O_WRONLY | O_CREAT, 0664); if (fd == -1) { - return Status(false, "can't open() for write"); + return tensorflow::errors::Internal("can't open() for write"); } size_t i = 0; @@ -201,13 +220,13 @@ Status SetContents(const string& filename, const string& contents, ssize_t written = write(fd, &contents[i], to_write); if (written == -1) { close(fd); - return Status(false, "write() error"); + return tensorflow::errors::Internal("write() error"); } i += written; } close(fd); - return Status(true, ""); + return tensorflow::Status::OK(); } string JoinPath(const string& base, const string& filename) { diff --git a/tensorflow/contrib/lite/toco/toco_port.h b/tensorflow/contrib/lite/toco/toco_port.h index 906792ef569e5b8dd2a40f6cf683fa8a35946012..17f82b9dd7dcc633aa204038b6d965f4eb6967bb 100644 --- a/tensorflow/contrib/lite/toco/toco_port.h +++ b/tensorflow/contrib/lite/toco/toco_port.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "google/protobuf/text_format.h" #include "tensorflow/contrib/lite/toco/format_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/platform.h" #if defined(PLATFORM_GOOGLE) @@ -33,28 +34,26 @@ limitations under the License. #define TFLITE_PROTO_NS google::protobuf #endif -namespace toco { -namespace port { - -class Status { - public: - static Status OK() { return Status(true, ""); } - - // Create a failed status with no message. - Status() {} - - Status(bool ok, const string& message) : ok_(ok), message_(message) {} - - void AppendMessage(const string& message) { message_ += message; } +#ifdef __ANDROID__ +#include +namespace std { - bool ok() const { return ok_; } +template +std::string to_string(T value) +{ + std::ostringstream os ; + os << value ; + return os.str() ; +} - const string error_message() const { return message_; } +#ifdef __ARM_ARCH_7A__ +double round(double x); +#endif +} +#endif - private: - bool ok_ = false; - string message_; -}; +namespace toco { +namespace port { void InitGoogle(const char* usage, int* argc, char*** argv, bool remove_flags); void CheckInitGoogleIsDone(const char* message); @@ -65,14 +64,14 @@ inline Options Defaults() { Options o; return o; } -Status GetContents(const string& filename, string* contents, - const Options& options); -Status SetContents(const string& filename, const string& contents, - const Options& options); +tensorflow::Status GetContents(const string& filename, string* contents, + const Options& options); +tensorflow::Status SetContents(const string& filename, const string& contents, + const Options& options); string JoinPath(const string& base, const string& filename); -Status Writable(const string& filename); -Status Readable(const string& filename, const Options& options); -Status Exists(const string& filename, const Options& options); +tensorflow::Status Writable(const string& filename); +tensorflow::Status Readable(const string& filename, const Options& options); +tensorflow::Status Exists(const string& filename, const Options& options); } // namespace file // Copy `src` string to `dest`. User must ensure `dest` has enough space. diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index b5531ca2f4785e0c95703f95977be93a0ba2a8e2..1fe76f8163cdf23b27f8baaf2d9c6d99b1aa3747 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -263,12 +263,15 @@ void Transform(const TocoFlags& toco_flags, Model* model) { if (!toco_flags.debug_disable_recurrent_cell_fusion()) { transformations.Add(new IdentifyLstmCell); } - if (output_format == TFLITE) { + if (output_format == TFLITE && toco_flags.split_tflite_lstm_inputs()) { transformations.Add(new toco::SplitLstmCellInputs); } else { transformations.Add(new toco::MergeLstmCellInputs); } } + if (toco_flags.quantize_weights()) { + transformations.Add(new QuantizeWeights); + } transformations.Add(new ResolveConstantConcatenation); RunGraphTransformations(model, "general graph transformations", transformations); diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 1e6314f2dc78297c8bdacb19cf89292603695e3f..92bab5246cb85052b5e0216f1cb8a04736ae7a79 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" namespace toco { @@ -393,6 +393,9 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) + HANDLE_OPERATORTYPENAME_CASE(SparseToDense) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowEqual) + HANDLE_OPERATORTYPENAME_CASE(TensorFlowNotEqual) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -582,6 +585,13 @@ void UnextendShape(Shape* shape, int new_shape_size) { shape_dims.erase(shape_dims.begin(), shape_dims.begin() + size_reduction); } +bool IsValid(const Shape& shape) { + for (int i = 0; i < shape.dimensions_count(); ++i) { + if (shape.dims(i) < 1) return false; + } + return true; +} + void CheckShapeDimensions(const Shape& shape) { for (int i = 0; i < shape.dimensions_count(); ++i) { CHECK_GE(shape.dims()[i], 1) << "shape has dimension 0 at index << " << i @@ -1862,18 +1872,15 @@ void GetShuffleShape(AxesOrder input_axes_order, AxesOrder output_axes_order, output_axes_order == AxesOrder::kHWIO) { // 3210 <- 3210 // HWIO <- OHWI - (*shuffle)[0] = 1; - (*shuffle)[1] = 2; - (*shuffle)[2] = 3; - (*shuffle)[3] = 0; + *shuffle = {1, 2, 3, 0}; } else if (input_axes_order == AxesOrder::kHWIO && output_axes_order == AxesOrder::kOHWI) { // 3210 <- 3210 // OHWI <- HWIO - (*shuffle)[0] = 3; - (*shuffle)[1] = 0; - (*shuffle)[2] = 1; - (*shuffle)[3] = 2; + *shuffle = {3, 0, 1, 2}; + } else if (input_axes_order == AxesOrder::kOHWI && + output_axes_order == AxesOrder::kHWOI) { + *shuffle = {1, 2, 0, 3}; } else { LOG(FATAL) << "Bad shuffle"; } @@ -2019,6 +2026,8 @@ int AxesCount(AxesOrder axes_order) { return 4; case AxesOrder::kNHWC: return 4; + case AxesOrder::kHWOI: + return 4; default: LOG(FATAL) << "Bad AxesOrder"; return 0; diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 1f596ca8e5a28f17e816c33eea03725d16f7ce12..7681ce9d39ec56f9447896682b52bd4efb1d0e54 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -26,14 +26,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/core/platform/logging.h" #if TOCO_SUPPORT_PORTABLE_PROTOS -#include "third_party/protobuf/src/google/protobuf/text_format.h" +#include "third_party/protobuf/include/google/protobuf/text_format.h" #endif // TOCO_SUPPORT_PORTABLE_PROTOS #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/runtime/types.h" #include "tensorflow/contrib/lite/toco/toco_flags.pb.h" -#include "tensorflow/contrib/lite/toco/toco_port.h" #include "tensorflow/contrib/lite/toco/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" // TODO(aselle): Replace with using a container specific hash override instead. namespace std { @@ -112,7 +113,9 @@ void ExtendShape(Shape* shape, int new_shape_size); // TODO(b/36075966): Clean up when dims superseded by array shape. void UnextendShape(Shape* shape, int new_shape_size); -// Checks (using CHECK) that all dimensions of 'shape' are at least 1. +// Checks that all dimensions of 'shape' are at least 1. +bool IsValid(const Shape& shape); +// Same as above, but reports error using CHECK. void CheckShapeDimensions(const Shape& shape); // Given two shapes with potentially different dimensionality and dimension @@ -315,7 +318,7 @@ void UseArraysExtraInfo(Model* model, bool quantize_output); // doesn't have enough range to represent the sum of elements, an error is // returned. template -port::Status NumElements(const std::vector& shape, U* num_elements) { +tensorflow::Status NumElements(const std::vector& shape, U* num_elements) { static_assert( std::numeric_limits::max() <= std::numeric_limits::max(), "vector type exceed capabilities of NumElements"); @@ -326,17 +329,17 @@ port::Status NumElements(const std::vector& shape, U* num_elements) { // TensorFlow's shapes sometimes include -1 to represent an "unknown" // size but TOCO isn't able to create arrays of unknown sizes and will // crash in RequiredBufferSizeForShape(). - return port::Status(false, - "Tensor shape should not include negative values"); + return tensorflow::errors::InvalidArgument( + "Tensor shape should not include negative values"); } if (static_cast(dim) > std::numeric_limits::max() / *num_elements) { *num_elements = 0; - return port::Status(false, "Tensor shape is too large"); + return tensorflow::errors::InvalidArgument("Tensor shape is too large"); } *num_elements *= dim; } - return port::Status::OK(); + return tensorflow::Status::OK(); } } // namespace toco diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc index 87fd30db2cf54824a3c34ed875291d898f1a9e38..a683867374c8b8dcb274478adf6b5fa0691d1c5a 100644 --- a/tensorflow/contrib/lite/toco/tooling_util_test.cc +++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/toco/model.h" #include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/lib/core/status.h" namespace toco { @@ -99,7 +100,7 @@ static const char kLargeTensorMessage[] = "Tensor shape is too large"; TEST(NumElementsTest, Int) { int count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -114,7 +115,7 @@ TEST(NumElementsTest, Int) { TEST(NumElementsTest, Int32) { int32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -129,7 +130,7 @@ TEST(NumElementsTest, Int32) { TEST(NumElementsTest, Int64) { int64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 32767}, &count); EXPECT_TRUE(status.ok()); @@ -144,7 +145,7 @@ TEST(NumElementsTest, Int64) { TEST(NumElementsTest, UnsignedInt32) { uint32_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{1024, 2048, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -159,7 +160,7 @@ TEST(NumElementsTest, UnsignedInt32) { TEST(NumElementsTest, UnsignedInt64) { uint64_t count; - port::Status status = port::Status::OK(); + tensorflow::Status status = tensorflow::Status::OK(); status = NumElements(std::vector{16777216, 16777216, 65535}, &count); diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 824a164651073bac846a514505726a8ee85cc41d..5913847329eeae7373d0d21834dd37327e4068c4 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -7,6 +7,8 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +common_copts = ["-Wall"] + py_binary( name = "visualize", srcs = ["visualize.py"], @@ -28,34 +30,6 @@ tf_cc_binary( ], ) -tf_cc_binary( - name = "benchmark_model", - srcs = ["benchmark_model.cc"], - linkopts = select({ - "//tensorflow:android": [ - "-pie", - "-landroid", - "-lm", - "-z defs", - "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export - ], - "//conditions:default": [], - }), - deps = [ - "//tensorflow/contrib/lite:framework", - "//tensorflow/contrib/lite:string_util", - "//tensorflow/contrib/lite/kernels:builtin_ops", - ] + select({ - "//tensorflow:android": [ - "//tensorflow/core:android_tensorflow_lib", - ], - "//conditions:default": [ - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], - }), -) - cc_library( name = "gen_op_registration", srcs = ["gen_op_registration.cc"], diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..8857062c000201e1077469fc36e3bf2760924a30 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/BUILD @@ -0,0 +1,89 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +common_copts = ["-Wall"] + tflite_copts() + +cc_binary( + name = "benchmark_model", + srcs = [ + "benchmark_main.cc", + "logging.h", + ], + copts = common_copts, + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + ":benchmark_tflite_model_lib", + ], +) + +cc_library( + name = "command_line_flags", + srcs = ["command_line_flags.cc"], + hdrs = ["command_line_flags.h"], + copts = common_copts, +) + +cc_test( + name = "command_line_flags_test", + srcs = ["command_line_flags_test.cc"], + copts = common_copts, + visibility = ["//visibility:private"], + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "benchmark_tflite_model_lib", + srcs = [ + "benchmark_tflite_model.cc", + "logging.h", + ], + hdrs = ["benchmark_tflite_model.h"], + copts = common_copts, + deps = [ + ":benchmark_model_lib", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + ], +) + +cc_library( + name = "benchmark_model_lib", + srcs = [ + "benchmark_model.cc", + "logging.h", + ], + hdrs = ["benchmark_model.h"], + copts = common_copts, + deps = [ + ":command_line_flags", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/profiling:profile_summarizer", + "//tensorflow/contrib/lite/profiling:profiler", + "//tensorflow/contrib/lite/profiling:time", + "//tensorflow/core:stats_calculator_portable", + ], +) + +tflite_portable_test_suite() diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c10826afff6d5569545d4b7df73c88d24d9dcd1a --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/README.md @@ -0,0 +1,154 @@ +# TFLite Model Benchmark Tool + +## Description + +A simple C++ binary to benchmark a TFLite model and its individual operators, +both on desktop machines and on Android. + +## To build/install/run + +### On Android: + +(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android to edit the `WORKSPACE` to configure the android NDK/SDK. + +(1) Build for your specific platform, e.g.: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Connect your phone. Push the binary to your phone with adb push + (make the directory if required): + +``` +adb push bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model /data/local/tmp +``` + +(3) Make the binary executable. + +``` +adb shell chmod +x /data/local/tmp/benchmark_model +``` + +(4) Push the compute graph that you need to test. For example: + +``` +adb push mobilenet_quant_v1_224.tflite /data/local/tmp +``` + +(5) Run the benchmark. For example: + +``` +adb shell /data/local/tmp/benchmark_model \ + --graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \ + --input_layer="Placeholder" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=4 +``` + +### On desktop: +(1) build the binary + +``` +bazel build -c opt tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` + +(2) Run on your compute graph, similar to the Android case but without the need of adb shell. +For example: + +``` +bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \ + --graph=mobilenet_quant_v1_224.tflite \ + --input_layer="Placeholder" \ + --input_layer_shape="1,224,224,3" \ + --num_threads=4 +``` + +The MobileNet graph used as an example here may be downloaded from +https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + +## Profiling model operators +The benchmark model binary also allows you to profile operators and give execution times of each operator. To do this, +compile the binary with a compiler flag that enables profiling to be compiled in. Pass **--copt=-DTFLITE_PROFILING_ENABLED** +to compile benchmark with profiling support. +For example, to compile with profiling support on Android, add this flag to the previous command: + +``` +bazel build -c opt \ + --config=android_arm \ + --cxxopt='--std=c++11' \ + --copt=-DTFLITE_PROFILING_ENABLED \ + tensorflow/contrib/lite/tools/benchmark:benchmark_model +``` +This compiles TFLite with profiling enabled, now you can run the benchmark binary like before. The binary will produce detailed statistics for each operation similar to those shown below: + +``` + +============================== Run Order ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 0.000 4.269 4.269 0.107% 0.107% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + DEPTHWISE_CONV_2D 4.270 2.150 2.150 0.054% 0.161% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_depthwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.314% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + DEPTHWISE_CONV_2D 12.528 1.366 1.366 0.034% 0.348% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_depthwise/Relu6] + CONV_2D 13.895 4.195 4.195 0.105% 0.454% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_2_pointwise/Relu6] + DEPTHWISE_CONV_2D 18.091 1.260 1.260 0.032% 0.485% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_depthwise/Relu6] + CONV_2D 19.352 6.652 6.652 0.167% 0.652% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + DEPTHWISE_CONV_2D 26.005 0.698 0.698 0.018% 0.670% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_depthwise/Relu6] + CONV_2D 26.703 3.344 3.344 0.084% 0.754% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_4_pointwise/Relu6] + DEPTHWISE_CONV_2D 30.047 0.646 0.646 0.016% 0.770% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_depthwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.915% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + DEPTHWISE_CONV_2D 36.495 0.331 0.331 0.008% 0.924% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_depthwise/Relu6] + CONV_2D 36.826 2.838 2.838 0.071% 0.995% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_6_pointwise/Relu6] + DEPTHWISE_CONV_2D 39.665 0.439 0.439 0.011% 1.006% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_depthwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.139% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + DEPTHWISE_CONV_2D 45.399 0.352 0.352 0.009% 1.147% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_depthwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.281% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + DEPTHWISE_CONV_2D 51.075 0.357 0.357 0.009% 1.290% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_depthwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 1.433% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + DEPTHWISE_CONV_2D 57.126 0.366 0.366 0.009% 1.442% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_depthwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 1.579% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + DEPTHWISE_CONV_2D 62.966 0.364 0.364 0.009% 1.588% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_depthwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.724% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + DEPTHWISE_CONV_2D 68.735 0.155 0.155 0.004% 1.728% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_depthwise/Relu6] + CONV_2D 68.891 2.970 2.970 0.074% 1.802% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_12_pointwise/Relu6] + DEPTHWISE_CONV_2D 71.862 0.206 0.206 0.005% 1.807% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_depthwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 1.955% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + AVERAGE_POOL_2D 77.958 0.036 0.036 0.001% 1.956% 0.000 0 [MobilenetV1/Logits/AvgPool_1a/AvgPool] + CONV_2D 77.994 1.445 1.445 0.036% 1.992% 0.000 0 [MobilenetV1/Logits/Conv2d_1c_1x1/BiasAdd] + RESHAPE 79.440 0.002 0.002 0.000% 1.992% 0.000 0 [MobilenetV1/Predictions/Reshape] + SOFTMAX 79.443 0.029 0.029 0.001% 1.993% 0.000 0 [MobilenetV1/Predictions/Softmax] + +============================== Top by Computation Time ============================== + [node type] [start] [first] [avg ms] [%] [cdf%] [mem KB] [times called] [Name] + CONV_2D 19.352 6.652 6.652 0.167% 0.167% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_3_pointwise/Relu6] + CONV_2D 6.421 6.107 6.107 0.153% 0.320% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_1_pointwise/Relu6] + CONV_2D 72.069 5.888 5.888 0.148% 0.468% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_13_pointwise/Relu6] + CONV_2D 30.694 5.800 5.800 0.145% 0.613% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_5_pointwise/Relu6] + CONV_2D 51.432 5.693 5.693 0.143% 0.756% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_9_pointwise/Relu6] + CONV_2D 57.493 5.472 5.472 0.137% 0.893% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_10_pointwise/Relu6] + CONV_2D 63.330 5.404 5.404 0.136% 1.029% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_11_pointwise/Relu6] + CONV_2D 45.752 5.322 5.322 0.133% 1.162% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_8_pointwise/Relu6] + CONV_2D 40.105 5.293 5.293 0.133% 1.295% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_7_pointwise/Relu6] + CONV_2D 0.000 4.269 4.269 0.107% 1.402% 0.000 0 [MobilenetV1/MobilenetV1/Conv2d_0/Relu6] + +Number of nodes executed: 31 +============================== Summary by node type ============================== + [Node type] [count] [avg ms] [avg %] [cdf %] [mem KB] [times called] + CONV_2D 15 1.406 89.270% 89.270% 0.000 0 + DEPTHWISE_CONV_2D 13 0.169 10.730% 100.000% 0.000 0 + SOFTMAX 1 0.000 0.000% 100.000% 0.000 0 + RESHAPE 1 0.000 0.000% 100.000% 0.000 0 + AVERAGE_POOL_2D 1 0.000 0.000% 100.000% 0.000 0 + +Timings (microseconds): count=50 first=79449 curr=81350 min=77385 max=88213 avg=79732 std=1929 +Memory (bytes): count=0 +31 nodes observed + + +Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9 +``` + + diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..372d31e838e5666df492ee3156022249a2d97691 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_main.cc @@ -0,0 +1,37 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace tflite { +namespace benchmark { + +int Main(int argc, char** argv) { +#ifdef TFLITE_CUSTOM_OPS_HEADER + TFLITE_LOG(INFO) << "STARTING with custom ops!"; +#else + TFLITE_LOG(INFO) << "STARTING!"; +#endif + BenchmarkTfLiteModel benchmark; + BenchmarkLoggingListener listener; + benchmark.AddListener(&listener); + benchmark.Run(argc, argv); + return 0; +} +} // namespace benchmark +} // namespace tflite + +int main(int argc, char** argv) { return tflite::benchmark::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8a9a6112c1ec050be8d0bcfe9dc5f00df40d3ff --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" + +#include + +#include +#include + +#include "tensorflow/contrib/lite/profiling/time.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +namespace { +void SleepForSeconds(double sleep_seconds) { + if (sleep_seconds <= 0.0) { + return; + } + // Convert the run_delay string into a timespec. + timespec req; + req.tv_sec = static_cast(sleep_seconds); + req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; + // If requested, sleep between runs for an arbitrary amount of time. + // This can be helpful to determine the effect of mobile processor + // scaling and thermal throttling. +#ifdef PLATFORM_WINDOWS + Sleep(sleep_seconds * 1000); +#else + nanosleep(&req, nullptr); +#endif +} + +} // namespace + +namespace tflite { +namespace benchmark { +using tensorflow::Stat; + +void BenchmarkLoggingListener::OnBenchmarkEnd(const BenchmarkResults &results) { + auto inference_us = results.inference_time_us(); + auto init_us = results.startup_latency_us(); + auto warmup_us = results.warmup_time_us(); + TFLITE_LOG(INFO) << "Average inference timings in us: " + << "Warmup: " << warmup_us.avg() << ", " + << "Init: " << init_us << ", " + << "no stats: " << inference_us.avg(); +} + +std::vector BenchmarkModel::GetFlags() { + return { + Flag("num_runs", ¶ms_.num_runs, "number of runs"), + Flag("run_delay", ¶ms_.run_delay, "delay between runs in seconds"), + Flag("num_threads", ¶ms_.num_threads, "number of threads"), + Flag("benchmark_name", ¶ms_.benchmark_name, "benchmark name"), + Flag("output_prefix", ¶ms_.output_prefix, "benchmark output prefix"), + Flag("warmup_runs", ¶ms_.warmup_runs, + "how many runs to initialize model"), + }; +} + +void BenchmarkModel::LogFlags() { + TFLITE_LOG(INFO) << "Num runs: [" << params_.num_runs << "]"; + TFLITE_LOG(INFO) << "Inter-run delay (seconds): [" << params_.run_delay + << "]"; + TFLITE_LOG(INFO) << "Num threads: [" << params_.num_threads << "]"; + TFLITE_LOG(INFO) << "Benchmark name: [" << params_.benchmark_name << "]"; + TFLITE_LOG(INFO) << "Output prefix: [" << params_.output_prefix << "]"; + TFLITE_LOG(INFO) << "Warmup runs: [" << params_.warmup_runs << "]"; +} + +Stat BenchmarkModel::Run(int num_times, RunType run_type) { + Stat run_stats; + TFLITE_LOG(INFO) << "Running benchmark for " << num_times << " iterations "; + for (int run = 0; run < num_times; run++) { + listeners_.OnSingleRunStart(run_type); + int64_t start_us = profiling::time::NowMicros(); + RunImpl(); + int64_t end_us = profiling::time::NowMicros(); + listeners_.OnSingleRunEnd(); + + run_stats.UpdateStat(end_us - start_us); + SleepForSeconds(params_.run_delay); + } + + std::stringstream stream; + run_stats.OutputToStream(&stream); + TFLITE_LOG(INFO) << stream.str() << std::endl; + + return run_stats; +} + +void BenchmarkModel::Run(int argc, char **argv) { + if (!ParseFlags(argc, argv)) { + return; + } + + LogFlags(); + + listeners_.OnBenchmarkStart(params_); + int64_t initialization_start_us = profiling::time::NowMicros(); + Init(); + int64_t initialization_end_us = profiling::time::NowMicros(); + int64_t startup_latency_us = initialization_end_us - initialization_start_us; + TFLITE_LOG(INFO) << "Initialized session in " << startup_latency_us / 1e3 + << "ms"; + + uint64_t input_bytes = ComputeInputBytes(); + Stat warmup_time_us = Run(params_.warmup_runs, WARMUP); + Stat inference_time_us = Run(params_.num_runs, REGULAR); + listeners_.OnBenchmarkEnd( + {startup_latency_us, input_bytes, warmup_time_us, inference_time_us}); +} + +bool BenchmarkModel::ParseFlags(int argc, char **argv) { + auto flag_list = GetFlags(); + const bool parse_result = + Flags::Parse(&argc, const_cast(argv), flag_list); + if (!parse_result) { + std::string usage = Flags::Usage(argv[0], flag_list); + TFLITE_LOG(ERROR) << usage; + return false; + } + return ValidateFlags(); +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h new file mode 100644 index 0000000000000000000000000000000000000000..d48f693693c2cee0cd2e2a6f2b4c590998feffb3 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_model.h @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include "tensorflow/core/util/stats_calculator.h" + +namespace tflite { +namespace benchmark { + +enum RunType { + WARMUP, + REGULAR, +}; + +class BenchmarkResults { + public: + BenchmarkResults(int64_t startup_latency_us, uint64_t input_bytes, + tensorflow::Stat warmup_time_us, + tensorflow::Stat inference_time_us) + : startup_latency_us_(startup_latency_us), + input_bytes_(input_bytes), + warmup_time_us_(warmup_time_us), + inference_time_us_(inference_time_us) {} + + tensorflow::Stat inference_time_us() const { + return inference_time_us_; + } + tensorflow::Stat warmup_time_us() const { return warmup_time_us_; } + int64_t startup_latency_us() const { return startup_latency_us_; } + uint64_t input_bytes() const { return input_bytes_; } + double throughput_MB_per_second() const { + double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) / + inference_time_us_.sum(); + return bytes_per_sec / (1024.0 * 1024.0); + } + + private: + int64_t startup_latency_us_; + uint64_t input_bytes_; + tensorflow::Stat warmup_time_us_; + tensorflow::Stat inference_time_us_; +}; + +struct BenchmarkParams { + BenchmarkParams() + : num_runs(50), warmup_runs(1), run_delay(-1.0), num_threads(1) {} + int num_runs; + int warmup_runs; + float run_delay; + int num_threads; + std::string benchmark_name; + std::string output_prefix; +}; + +class BenchmarkListener { + public: + virtual void OnBenchmarkStart(const BenchmarkParams& params) {} + virtual void OnSingleRunStart(RunType runType) {} + virtual void OnSingleRunEnd() {} + virtual void OnBenchmarkEnd(const BenchmarkResults& results) {} + virtual ~BenchmarkListener() {} +}; + +// A listener that forwards its method calls to a collection of listeners. +class BenchmarkListeners : public BenchmarkListener { + public: + // Added a listener to the listener collection. + // |listener| is not owned by the instance of |BenchmarkListeners|. + // |listener| should not be null and should outlast the instance of + // |BenchmarkListeners|. + void AddListener(BenchmarkListener* listener) { + listeners_.push_back(listener); + } + + void OnBenchmarkStart(const BenchmarkParams& params) override { + for (auto listener : listeners_) { + listener->OnBenchmarkStart(params); + } + } + + void OnSingleRunStart(RunType runType) override { + for (auto listener : listeners_) { + listener->OnSingleRunStart(runType); + } + } + + void OnSingleRunEnd() override { + for (auto listener : listeners_) { + listener->OnSingleRunEnd(); + } + } + + void OnBenchmarkEnd(const BenchmarkResults& results) override { + for (auto listener : listeners_) { + listener->OnBenchmarkEnd(results); + } + } + + ~BenchmarkListeners() {} + + private: + // Use vector so listeners are invoked in the order they are added. + std::vector listeners_; +}; + +// Benchmark listener that just logs the results of benchmark run. +class BenchmarkLoggingListener : public BenchmarkListener { + void OnBenchmarkEnd(const BenchmarkResults& results) override; +}; + +// Benchmarks a model. +// +// Subclasses need to implement initialization and running of the model. +// The results can be collected by adding BenchmarkListener(s). +class BenchmarkModel { + public: + virtual ~BenchmarkModel() {} + bool ParseFlags(int argc, char** argv); + virtual void Init() = 0; + void Run(int argc, char** argv); + void AddListener(BenchmarkListener* listener) { + listeners_.AddListener(listener); + } + + protected: + virtual void LogFlags(); + virtual bool ValidateFlags() { return true; } + virtual std::vector GetFlags(); + virtual uint64_t ComputeInputBytes() = 0; + virtual tensorflow::Stat Run(int num_times, RunType run_type); + virtual void RunImpl() = 0; + BenchmarkParams params_; + BenchmarkListeners listeners_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f803cec197858953180d379c763ed7ebd34ee1d --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc @@ -0,0 +1,304 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/op_resolver.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/benchmark/logging.h" + +#ifdef TFLITE_CUSTOM_OPS_HEADER +void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); +#endif + +namespace tflite { +namespace benchmark { + +void ProfilingListener::SetInterpreter(tflite::Interpreter* interpreter) { + TFLITE_BENCHMARK_CHECK(interpreter); + interpreter_ = interpreter; + interpreter_->SetProfiler(&profiler_); +} + +void ProfilingListener::OnSingleRunStart(RunType run_type) { + if (run_type == REGULAR) { + profiler_.Reset(); + profiler_.StartProfiling(); + } +} + +void ProfilingListener::OnBenchmarkEnd(const BenchmarkResults& results) { + if (has_profiles_) { + TFLITE_LOG(INFO) << summarizer_.GetOutputString(); + } +} + +void ProfilingListener::OnSingleRunEnd() { + profiler_.StopProfiling(); + auto profile_events = profiler_.GetProfileEvents(); + has_profiles_ = !profile_events.empty(); + summarizer_.ProcessProfiles(profile_events, *interpreter_); +} + +namespace { + +std::vector Split(const std::string& str, const char delim) { + std::istringstream input(str); + std::vector results; + std::string item; + while (std::getline(input, item, delim)) { + results.push_back(item); + } + return results; +} + +template +bool SplitAndParse(const std::string& str, char delim, std::vector* values) { + std::istringstream input(str); + bool first = true; + while (!input.eof()) { + if (!first) { + char c; + input >> c; + if (c != delim) { + return false; + } + } else { + first = false; + } + T val; + input >> val; + if (!input.eof() && !input.good()) { + return false; + } + values->push_back(val); + } + return true; +} + +template +void FillRandomValue(T* ptr, const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + *ptr++ = random_func(); + } +} + +void FillRandomString(tflite::DynamicBuffer* buffer, + const std::vector& sizes, + const std::function& random_func) { + int num_elements = 1; + for (int dim : sizes) { + num_elements *= dim; + } + for (int i = 0; i < num_elements; ++i) { + auto str = random_func(); + buffer->AddString(str.data(), str.length()); + } +} + +bool PopulateInputLayerInfo( + const string& names_string, const string& shapes_string, + std::vector* info) { + std::vector names = Split(names_string, ','); + std::vector shapes = Split(shapes_string, ':'); + + if (names.size() != shapes.size()) { + TFLITE_LOG(ERROR) << "The number of items in" + << " --input_layer_shape (" << shapes_string << ", with " + << shapes.size() << " items)" + << " must match the number of items in" + << " --input_layer (" << names_string << ", with " + << names.size() << " items)." + << " For example --input_layer=input1,input2" + << " --input_layer_shape=1,224,224,4:1,20"; + return false; + } + + for (int i = 0; i < names.size(); ++i) { + info->push_back(BenchmarkTfLiteModel::InputLayerInfo()); + BenchmarkTfLiteModel::InputLayerInfo& input = info->back(); + + input.name = names[i]; + + TFLITE_BENCHMARK_CHECK(SplitAndParse(shapes[i], ',', &input.shape)) + << "Incorrect size string specified: " << shapes[i]; + for (int dim : input.shape) { + if (dim == -1) { + TFLITE_LOG(ERROR) + << "Any unknown sizes in the shapes (-1's) must be replaced" + << " with the size you want to benchmark with."; + return false; + } + } + } + + return true; +} + +} // namespace + +std::vector BenchmarkTfLiteModel::GetFlags() { + std::vector flags = BenchmarkTfLiteModel::BenchmarkModel::GetFlags(); + std::vector specific_flags = { + Flag("graph", &graph, "graph file name"), + Flag("input_layer", &input_layer_string, "input layer names"), + Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), + Flag("use_nnapi", &use_nnapi, "use nnapi api")}; + + flags.insert(flags.end(), specific_flags.begin(), specific_flags.end()); + return flags; +} + +void BenchmarkTfLiteModel::LogFlags() { + BenchmarkModel::LogFlags(); + TFLITE_LOG(INFO) << "Graph: [" << graph << "]"; + TFLITE_LOG(INFO) << "Input layers: [" << input_layer_string << "]"; + TFLITE_LOG(INFO) << "Input shapes: [" << input_layer_shape_string << "]"; + TFLITE_LOG(INFO) << "Use nnapi : [" << use_nnapi << "]"; +} + +bool BenchmarkTfLiteModel::ValidateFlags() { + if (graph.empty()) { + TFLITE_LOG(ERROR) + << "Please specify the name of your TF Lite input file with --graph"; + return false; + } + return PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, + &inputs); +} + +uint64_t BenchmarkTfLiteModel::ComputeInputBytes() { + TFLITE_BENCHMARK_CHECK(interpreter); + uint64_t total_input_bytes = 0; + for (int input : interpreter->inputs()) { + auto* t = interpreter->tensor(input); + total_input_bytes += t->bytes; + } + return total_input_bytes; +} + +void BenchmarkTfLiteModel::Init() { + model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); + if (!model) { + TFLITE_LOG(FATAL) << "Failed to mmap model " << graph; + } + TFLITE_LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + TFLITE_LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + TFLITE_LOG(FATAL) << "Failed to construct interpreter"; + } + profiling_listener_.SetInterpreter(interpreter.get()); + + if (params_.num_threads != -1) { + interpreter->SetNumThreads(params_.num_threads); + } + + interpreter->UseNNAPI(use_nnapi); + auto interpreter_inputs = interpreter->inputs(); + + if (!inputs.empty()) { + TFLITE_BENCHMARK_CHECK_EQ(inputs.size(), interpreter_inputs.size()) + << "Inputs mismatch: Model inputs #:" << interpreter_inputs.size() + << " expected: " << inputs.size(); + } + + // TFLITE_BENCHMARK_CHECK that all names and types match + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + TFLITE_BENCHMARK_CHECK_EQ(t->name, input.name) + << "Tensor # " << i << " is named " << t->name << " but flags call it " + << input.name; + } + + // Resize all non-string tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + if (t->type != kTfLiteString) { + interpreter->ResizeInputTensor(i, input.shape); + } + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to allocate tensors!"; + } + + // Set the values of the input tensors. + for (int j = 0; j < inputs.size(); ++j) { + const InputLayerInfo& input = inputs[j]; + int i = interpreter_inputs[j]; + TfLiteTensor* t = interpreter->tensor(i); + std::vector sizes = input.shape; + + // TODO(ahentz): below we ignore the O-th dimension (number of batches). + if (t->type == kTfLiteFloat32) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); + } else if (t->type == kTfLiteUInt8) { + FillRandomValue( + interpreter->typed_tensor(i), + std::vector(sizes.begin() + 1, sizes.end()), + []() { return static_cast(rand()) % 255; }); + } else if (t->type == kTfLiteString) { + tflite::DynamicBuffer buffer; + FillRandomString(&buffer, sizes, []() { + return "we're have some friends over saturday to hang out in the yard"; + }); + buffer.WriteToTensor(interpreter->tensor(i)); + } else { + TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name + << " of type " << t->type; + } + } +} + +void BenchmarkTfLiteModel::RunImpl() { + if (interpreter->Invoke() != kTfLiteOk) { + TFLITE_LOG(FATAL) << "Failed to invoke!"; + } +} + +} // namespace benchmark +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h new file mode 100644 index 0000000000000000000000000000000000000000..ffb93da964b2da0328616e749abd9c5a84189468 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h @@ -0,0 +1,86 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_TFLITE_MODEL_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/profiling/profile_summarizer.h" +#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h" + +namespace tflite { +namespace benchmark { + +// Dumps profiling events if profiling is enabled +class ProfilingListener : public BenchmarkListener { + public: + explicit ProfilingListener() : interpreter_(nullptr), has_profiles_(false) {} + + void SetInterpreter(Interpreter* interpreter); + + void OnSingleRunStart(RunType run_type) override; + + void OnSingleRunEnd() override; + + void OnBenchmarkEnd(const BenchmarkResults& results) override; + + private: + Interpreter* interpreter_; + profiling::Profiler profiler_; + profiling::ProfileSummarizer summarizer_; + bool has_profiles_; +}; + +// Benchmarks a TFLite model by running tflite interpreter. +class BenchmarkTfLiteModel : public BenchmarkModel { + public: + BenchmarkTfLiteModel() : use_nnapi(false) { + AddListener(&profiling_listener_); + } + + std::vector GetFlags() override; + void LogFlags() override; + bool ValidateFlags() override; + uint64_t ComputeInputBytes() override; + void Init() override; + void RunImpl() override; + virtual ~BenchmarkTfLiteModel() {} + + struct InputLayerInfo { + std::string name; + std::vector shape; + }; + + private: + std::unique_ptr model; + std::unique_ptr interpreter; + std::string graph; + std::string input_layer_string; + std::string input_layer_type_string; + std::string input_layer_shape_string; + std::string input_layer_values_string; + std::vector inputs; + bool use_nnapi; + ProfilingListener profiling_listener_; +}; + +} // namespace benchmark +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_BENCHMARK_TFLITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..8195fc44beb288eec3c020791b47eefa01536fb7 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.cc @@ -0,0 +1,194 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" + +#include +#include +#include +#include + +namespace tflite { +namespace { + +template +std::string ToString(T val) { + std::ostringstream stream; + stream << val; + return stream.str(); +} + +bool ParseFlag(const std::string& arg, const std::string& flag, + const std::function& parse_func, + bool* value_parsing_ok) { + *value_parsing_ok = true; + std::string flag_prefix = "--" + flag + "="; + if (arg.find(flag_prefix) != 0) { + return false; + } + bool has_value = arg.size() >= flag_prefix.size(); + *value_parsing_ok = has_value; + if (has_value) { + *value_parsing_ok = parse_func(arg.substr(flag_prefix.size())); + } + return true; +} + +template +bool ParseFlag(const std::string& flag_value, T* value) { + std::istringstream stream(flag_value); + T read_value; + stream >> read_value; + if (!stream.eof() && !stream.good()) { + return false; + } + *value = read_value; + return true; +} + +bool ParseBoolFlag(const std::string& flag_value, bool* value) { + if (flag_value != "true" && flag_value != "false") { + return false; + } + + *value = (flag_value == "true"); + return true; +} + +bool ParseStringFlag(const std::string& flag_value, std::string* value) { + *value = flag_value; + return true; +} + +} // namespace + +Flag::Flag(const char* name, int32_t* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_INT32), + value_hook_([dst](const std::string& flag_value) { + return ParseFlag(flag_value, dst); + }), + default_for_display_(ToString(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, int64_t* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_INT64), + value_hook_([dst](const std::string& flag_value) { + return ParseFlag(flag_value, dst); + }), + default_for_display_(ToString(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, float* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + value_hook_([dst](const std::string& flag_value) { + return ParseFlag(flag_value, dst); + }), + default_for_display_(ToString(*dst)), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, bool* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_BOOL), + value_hook_([dst](const std::string& flag_value) { + return ParseBoolFlag(flag_value, dst); + }), + default_for_display_((*dst) ? "true" : "false"), + usage_text_(usage_text) {} + +Flag::Flag(const char* name, std::string* dst, const std::string& usage_text) + : name_(name), + type_(TYPE_STRING), + value_hook_([dst](const std::string& flag_value) { + return ParseStringFlag(flag_value, dst); + }), + default_for_display_(*dst), + usage_text_(usage_text) {} + +bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const { + return ParseFlag(arg, name_, value_hook_, value_parsing_ok); +} + +std::string Flag::GetTypeName() const { + switch (type_) { + case TYPE_INT32: + return "int32"; + case TYPE_INT64: + return "int64"; + case TYPE_FLOAT: + return "float"; + case TYPE_BOOL: + return "bool"; + case TYPE_STRING: + return "string"; + } + + return "unknown"; +} + +/*static*/ bool Flags::Parse(int* argc, const char** argv, + const std::vector& flag_list) { + bool result = true; + std::vector unknown_flags; + for (int i = 1; i < *argc; ++i) { + if (std::string(argv[i]) == "--") { + while (i < *argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + bool was_found = false; + for (const Flag& flag : flag_list) { + bool value_parsing_ok; + was_found = flag.Parse(argv[i], &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + if (was_found) { + break; + } + } + if (!was_found) { + unknown_flags.push_back(argv[i]); + } + } + int dst = 1; // Skip argv[0] + for (auto f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + *argc = unknown_flags.size() + 1; + return result && (*argc < 2 || std::strcmp(argv[1], "--help") != 0); +} + +/*static*/ std::string Flags::Usage(const std::string& cmdline, + const std::vector& flag_list) { + std::ostringstream usage_text; + usage_text << "usage: " << cmdline << "\n"; + if (!flag_list.empty()) { + usage_text << "Flags:\n"; + } + + for (const Flag& flag : flag_list) { + auto type_name = flag.GetTypeName(); + usage_text << "\t"; + usage_text << "--" << flag.name_ << "=" << flag.default_for_display_; + usage_text << "\t" << type_name << "\t" << flag.usage_text_ << "\n"; + } + return usage_text.str(); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..36f9e64767315a317338bc4d2db2ec2d43bee875 --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags.h @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_COMMAND_LINE_FLAGS_H_ + +#include +#include +#include + +namespace tflite { +// A simple command-line argument parsing module. +// Dependency free simplified port of core/util/command_line_flags. +// This class is written for benchmarks and uses inefficient string +// concatenation. This was written to avoid dependency on tensorflow/core/util +// which transitively brings in a lot of other dependencies that are not +// necessary for tflite benchmarking code. +// The recommended way of using it is with local variables and an initializer +// list of Flag objects, for example: +// +// int some_int = 10; +// bool some_switch = false; +// std::string some_name = "something"; +// std::vector flag_list = { +// Flag("some_int", &some_int, "an integer that affects X"), +// Flag("some_switch", &some_switch, "a bool that affects Y"), +// Flag("some_name", &some_name, "a std::string that affects Z") +// }; +// // Get usage message before ParseFlags() to capture default values. +// std::string usage = Flag::Usage(argv[0], flag_list); +// bool parsed_values_ok = Flags::Parse(&argc, argv, flag_list); +// +// tensorflow::port::InitMain(usage.c_str(), &argc, &argv); +// if (argc != 1 || !parsed_values_ok) { +// ...output usage and error message... +// } +// +// The argc and argv values are adjusted by the Parse function so all that +// remains is the program name (at argv[0]) and any unknown arguments fill the +// rest of the array. This means you can check for flags that weren't understood +// by seeing if argv is greater than 1. +// The result indicates if there were any errors parsing the values that were +// passed to the command-line switches. For example, --some_int=foo would return +// false because the argument is expected to be an integer. +// +// NOTE: Unlike gflags-style libraries, this library is intended to be +// used in the `main()` function of your binary. It does not handle +// flag definitions that are scattered around the source code. + +// A description of a single command line flag, holding its name, type, usage +// text, and a pointer to the corresponding variable. +class Flag { + public: + Flag(const char* name, int32_t* dst, const std::string& usage_text); + Flag(const char* name, int64_t* dst, const std::string& usage_text); + Flag(const char* name, bool* dst, const std::string& usage_text); + Flag(const char* name, std::string* dst, const std::string& usage_text); + Flag(const char* name, float* dst, const std::string& usage_text); + + private: + friend class Flags; + + bool Parse(const std::string& arg, bool* value_parsing_ok) const; + + std::string name_; + enum { + TYPE_INT32, + TYPE_INT64, + TYPE_BOOL, + TYPE_STRING, + TYPE_FLOAT, + } type_; + + std::string GetTypeName() const; + + std::function value_hook_; + std::string default_for_display_; + + std::string usage_text_; +}; + +class Flags { + public: + // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag + // instances matching flags in flaglist[]. Update the variables associated + // with matching flags, and remove the matching arguments from (*argc, argv). + // Return true iff all recognized flag values were parsed correctly, and the + // first remaining argument is not "--help". + static bool Parse(int* argc, const char** argv, + const std::vector& flag_list); + + // Return a usage message with command line cmdline, and the + // usage_text strings in flag_list[]. + static std::string Usage(const std::string& cmdline, + const std::vector& flag_list); +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_COMMAND_LINE_FLAGS_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..620d61b027d30044ba9d449a8e308375f72ad76f --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/command_line_flags_test.cc @@ -0,0 +1,166 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/tools/benchmark/command_line_flags.h" +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +TEST(CommandLineFlagsTest, BasicUsage) { + int some_int32 = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something_a"; + float some_float = -23.23f; + const char* argv_strings[] = {"program_name", + "--some_int32=20", + "--some_int64=214748364700", + "--some_switch=true", + "--some_name=somethingelse", + "--some_float=42.0"}; + int argc = 6; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + { + Flag("some_int32", &some_int32, "some int32"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name"), + Flag("some_float", &some_float, "some float"), + }); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(20, some_int32); + EXPECT_EQ(214748364700, some_int64); + EXPECT_EQ(true, some_switch); + EXPECT_EQ("somethingelse", some_name); + EXPECT_NEAR(42.0f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, EmptyStringFlag) { + int argc = 2; + std::string some_string = "invalid"; + const char* argv_strings[] = {"program_name", "--some_string="}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_string", &some_string, "some string")}); + + EXPECT_EQ(true, parsed_ok); + EXPECT_EQ(some_string, ""); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadIntValue) { + int some_int = 10; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_int=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_int", &some_int, "some int")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(10, some_int); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadBoolValue) { + bool some_switch = false; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_switch=notabool"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_switch", &some_switch, "some switch")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_EQ(false, some_switch); + EXPECT_EQ(argc, 1); +} + +TEST(CommandLineFlagsTest, BadFloatValue) { + float some_float = -23.23f; + int argc = 2; + const char* argv_strings[] = {"program_name", "--some_float=notanumber"}; + bool parsed_ok = + Flags::Parse(&argc, reinterpret_cast(argv_strings), + {Flag("some_float", &some_float, "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_NEAR(-23.23f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + +// Return whether str==pat, but allowing any whitespace in pat +// to match zero or more whitespace characters in str. +static bool MatchWithAnyWhitespace(const std::string& str, + const std::string& pat) { + bool matching = true; + int pat_i = 0; + for (int str_i = 0; str_i != str.size() && matching; str_i++) { + if (isspace(str[str_i])) { + matching = (pat_i != pat.size() && isspace(pat[pat_i])); + } else { + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + matching = (pat_i != pat.size() && str[str_i] == pat[pat_i++]); + } + } + while (pat_i != pat.size() && isspace(pat[pat_i])) { + pat_i++; + } + return (matching && pat_i == pat.size()); +} + +TEST(CommandLineFlagsTest, UsageString) { + int some_int = 10; + int64_t some_int64 = 21474836470; // max int32 is 2147483647 + bool some_switch = false; + std::string some_name = "something"; + // Don't test float in this case, because precision is hard to predict and + // match against, and we don't want a flakey test. + const std::string tool_name = "some_tool_name"; + std::string usage = Flags::Usage( + tool_name + " ", {Flag("some_int", &some_int, "some int"), + Flag("some_int64", &some_int64, "some int64"), + Flag("some_switch", &some_switch, "some switch"), + Flag("some_name", &some_name, "some name")}); + // Match the usage message, being sloppy about whitespace. + const char* expected_usage = + " usage: some_tool_name \n" + "Flags:\n" + "--some_int=10\tint32\tsome int\n" + "--some_int64=21474836470\tint64\tsome int64\n" + "--some_switch=false\tbool\tsome switch\n" + "--some_name=something\tstring\tsome name\n"; + ASSERT_EQ(MatchWithAnyWhitespace(usage, expected_usage), true) << usage; + + // Again but with no flags. + usage = Flags::Usage(tool_name, {}); + ASSERT_EQ(MatchWithAnyWhitespace(usage, " usage: some_tool_name\n"), true) + << usage; +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/tools/benchmark/logging.h b/tensorflow/contrib/lite/tools/benchmark/logging.h new file mode 100644 index 0000000000000000000000000000000000000000..9e9292e2feacf0eff0751534f02cdacd21c9b0dd --- /dev/null +++ b/tensorflow/contrib/lite/tools/benchmark/logging.h @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ +#define TENSORFLOW_CONTRIB_LITE_TOOLS_LOGGING_H_ + +// LOG and CHECK macros for benchmarks. + +#include +#include +#include + +namespace tflite { +namespace logging { +// A wrapper that logs to stderr. +// +// Used for TFLITE_LOG and TFLITE_BENCHMARK_CHECK macros. +class LoggingWrapper { + public: + enum class LogSeverity : int { + INFO = 0, + WARN = 1, + ERROR = 2, + FATAL = 3, + }; + LoggingWrapper(LogSeverity severity) + : severity_(severity), should_log_(true) {} + LoggingWrapper(LogSeverity severity, bool log) + : severity_(severity), should_log_(log) {} + std::stringstream& Stream() { return stream_; } + ~LoggingWrapper() { + if (should_log_) { + std::cerr << stream_.str() << std::endl; + if (severity_ == LogSeverity::FATAL) { + std::flush(std::cerr); + std::abort(); + } + } + } + + private: + std::stringstream stream_; + LogSeverity severity_; + bool should_log_; +}; + +} // namespace logging + +} // namespace tflite + +#define TFLITE_LOG(severity) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::severity) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK(condition) \ + tflite::logging::LoggingWrapper( \ + tflite::logging::LoggingWrapper::LogSeverity::FATAL, \ + (condition) ? false : true) \ + .Stream() + +#define TFLITE_BENCHMARK_CHECK_EQ(a, b) TFLITE_BENCHMARK_CHECK(a == b) + +#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_BENCHMARK_LOGGING_H_ diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc deleted file mode 100644 index 869c531b3e3db37f634761e7b25d4ffa1e8304a7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/lite/tools/benchmark_model.cc +++ /dev/null @@ -1,475 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include -#include -#include -#include - -#include "tensorflow/contrib/lite/kernels/register.h" -#include "tensorflow/contrib/lite/model.h" -#include "tensorflow/contrib/lite/op_resolver.h" -#include "tensorflow/contrib/lite/string_util.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" - -#ifdef TFLITE_CUSTOM_OPS_HEADER -void RegisterSelectedOps(::tflite::MutableOpResolver* resolver); -#endif - -namespace tflite { - -using ::tensorflow::Env; -using ::tensorflow::str_util::Split; -using ::tensorflow::str_util::SplitAndParseAsFloats; -using ::tensorflow::str_util::SplitAndParseAsInts; - -struct InputLayerInfo { - string name; - TfLiteType data_type; - std::vector shape; - // Note that initialization_values is currently unused. - std::vector initialization_values; -}; - -template -void FillRandomValue(T* ptr, const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - *ptr++ = random_func(); - } -} - -void FillRandomString(tflite::DynamicBuffer* buffer, - const std::vector& sizes, - const std::function& random_func) { - int num_elements = 1; - for (int dim : sizes) { - num_elements *= dim; - } - for (int i = 0; i < num_elements; ++i) { - auto str = random_func(); - buffer->AddString(str.data(), str.length()); - } -} - -TfLiteType TfLiteTypeFromString(const string& input_layer_type) { - if (input_layer_type == "string") - return kTfLiteString; - else if (input_layer_type == "float") - return kTfLiteFloat32; - else if (input_layer_type == "uint8") - return kTfLiteUInt8; - else if (input_layer_type == "int32") - return kTfLiteInt32; - else if (input_layer_type == "int64") - return kTfLiteInt64; - else - return kTfLiteNoType; -} - -std::vector ShapeFromTfLiteTensor(TfLiteTensor* t) { - std::vector result; - result.reserve(t->dims->size); - for (int i = 0; i < t->dims->size; ++i) { - result.push_back(t->dims->data[i]); - } - CHECK(!result.empty()) << "Found no shapes in model"; - return result; -} - -bool CreateInterpreter(const string& graph, - std::unique_ptr* model, - std::unique_ptr* interpreter) { - *model = tflite::FlatBufferModel::BuildFromFile(graph.c_str()); - if (!model) { - std::cerr << "Failed to load model " << graph << std::endl; - return false; - } - -#ifdef TFLITE_CUSTOM_OPS_HEADER - tflite::MutableOpResolver resolver; - RegisterSelectedOps(&resolver); -#else - tflite::ops::builtin::BuiltinOpResolver resolver; -#endif - - tflite::InterpreterBuilder(*(model->get()), resolver)(interpreter); - if (!(*interpreter)) { - std::cerr << "Failed to construct interpreter" << std::endl; - return false; - } - - return true; -} - -bool PrepareInterpreter(const std::vector inputs, - int num_threads, bool use_nnapi, - Interpreter* interpreter) { - if (num_threads != -1) { - interpreter->SetNumThreads(num_threads); - } - - interpreter->UseNNAPI(use_nnapi); - - // Check that all names and types match - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - CHECK_EQ(t->name, input.name) - << "Tensor # " << i << " is named " << t->name - << " but flags call it " << input.name; - CHECK_EQ(t->type, input.data_type) - << "Could not match the type of input tensor " << t->name; - } - } - - // Resize all non-string tensors. - for (const InputLayerInfo& input : inputs) { - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - if (t->type != kTfLiteString) { - interpreter->ResizeInputTensor(i, input.shape); - } - } - } - - if (interpreter->AllocateTensors() != kTfLiteOk) { - std::cerr << "Failed to allocate tensors!" << std::endl; - return false; - } - - // Set the values of the input tensors. - for (int i : interpreter->inputs()) { - TfLiteTensor* t = interpreter->tensor(i); - std::vector sizes = ShapeFromTfLiteTensor(t); - - // TODO(ahentz): below we ignore the O-th dimension (number of batches). - if (t->type == kTfLiteFloat32) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) / RAND_MAX - 0.5f; }); - } else if (t->type == kTfLiteUInt8) { - FillRandomValue( - interpreter->typed_tensor(i), - std::vector(sizes.begin() + 1, sizes.end()), - []() { return static_cast(rand()) % 255; }); - } else if (t->type == kTfLiteString) { - tflite::DynamicBuffer buffer; - FillRandomString(&buffer, sizes, []() { - return "we're have some friends over saturday to hang out in the yard"; - }); - buffer.WriteToTensor(interpreter->tensor(i)); - } else { - std::cerr << "Don't know how to populate tensor " << t->name - << " of type " << t->type << std::endl; - return false; - } - } - return true; -} - -bool PopulateInputLayerInfo(const string& names_string, - const string& shapes_string, - const string& types_string, - const string& values_string, - std::vector* info) { - std::vector names = Split(names_string, ','); - std::vector shapes = Split(shapes_string, ':'); - std::vector types = Split(types_string, ','); - std::vector values = Split(values_string, ':'); - - if (names.size() != shapes.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_shape (" << shapes_string << ", with " - << shapes.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_shape=1,224,224,4:1,20"; - return false; - } - if (names.size() != types.size()) { - LOG(ERROR) << "The number of items in" - << " --input_layer_type (" << types_string << ", with " - << types.size() << " items)" - << " must match the number of items in" - << " --input_layer (" << names_string << ", with " - << names.size() << " items)." - << " For example --input_layer=input1,input2" - << " --input_layer_type=float,int"; - return false; - } - - for (int i = 0; i < names.size(); ++i) { - info->push_back(InputLayerInfo()); - InputLayerInfo& input = info->back(); - - input.name = names[i]; - - input.data_type = TfLiteTypeFromString(types[i]); - CHECK(input.data_type != kTfLiteNoType) - << types[i] << " was an invalid type"; - - CHECK(SplitAndParseAsInts(shapes[i], ',', &input.shape)) - << "Incorrect size string specified: " << shapes[i]; - for (int dim : input.shape) { - if (dim == -1) { - LOG(ERROR) << "Any unknown sizes in the shapes (-1's) must be replaced" - << " with the size you want to benchmark with."; - return false; - } - } - - if (i < values.size()) { - CHECK(SplitAndParseAsFloats(values[i], ',', &input.initialization_values)) - << "Incorrect initialization values string specified: " << values[i]; - } - } - - return true; -} - -bool RunBenchmark(Interpreter* interpreter, int64_t* inference_time_us) { - const int64_t start_time = Env::Default()->NowMicros(); - - if (interpreter->Invoke() != kTfLiteOk) { - std::cerr << "Failed to invoke!"; - return false; - } - - const int64_t end_time = Env::Default()->NowMicros(); - *inference_time_us = end_time - start_time; - return true; -} - -class Latencies { - public: - void AddMeasurement(int64_t time_us) { - max_ = std::max(time_us, max_); - min_ = std::min(time_us, min_); - ++count_; - sum_ += time_us; - squared_sum_ += static_cast(time_us) * time_us; - } - - double avg() const { - if (count_ == 0) return std::numeric_limits::quiet_NaN(); - return static_cast(sum_) / count_; - } - - int64_t std_deviation() const { - if (count_ == 0 || min_ == max_) return 0; - return sqrt(squared_sum_ / count_ - avg() * avg()); - } - - void OutputToStream(std::ostream* stream) const { - *stream << "count=" << count_; - if (count_ == 0) return; - *stream << " min=" << min_ << " max=" << max_; - *stream << " avg=" << avg() << " std=" << std_deviation(); - } - - private: - int64_t count_ = 0; - int64_t min_ = std::numeric_limits::max(); - int64_t max_ = std::numeric_limits::min(); - int64_t sum_ = 0; - double squared_sum_ = 0; -}; - -bool TimeMultipleRuns(Interpreter* interpreter, double sleep_seconds, - int num_runs, int64* total_time_us) { - // Convert the run_delay string into a timespec. - timespec req; - req.tv_sec = static_cast(sleep_seconds); - req.tv_nsec = (sleep_seconds - req.tv_sec) * 1000000000; - - *total_time_us = 0; - - std::cout << "Running benchmark for " << num_runs - << " iterations: " << std::endl; - - Latencies latencies; - for (int i = 0; i < num_runs; ++i) { - int64_t time_us; - bool run_status = RunBenchmark(interpreter, &time_us); - latencies.AddMeasurement(time_us); - *total_time_us += time_us; - if (!run_status) { - std::cout << "Failed on run " << i << std::endl; - return false; - } - - // If requested, sleep between runs for an arbitrary amount of time. - // This can be helpful to determine the effect of mobile processor - // scaling and thermal throttling. - if (sleep_seconds > 0.0) { -#ifdef PLATFORM_WINDOWS - Sleep(sleep_seconds * 1000); -#else - nanosleep(&req, nullptr); -#endif - } - } - latencies.OutputToStream(&std::cout); - std::cout << std::endl; - - return true; -} - -int Main(int argc, char** argv) { - using tensorflow::Flag; - using tensorflow::Flags; - - string graph; // e.g.: /data/local/tmp/tfl_inception-v1_model.fb - string input_layer_string; // e.g.: input - string input_layer_shape_string; // e.g.: 1,224,224,3 - string input_layer_type_string; // e.g.: float - string input_layer_values_string; - string output_layer_string; // e.g.: output - int num_runs = 50; - string run_delay = "-1.0"; - int num_threads = 1; - string benchmark_name = ""; - string output_prefix = ""; - int warmup_runs = 1; - bool use_nnapi = false; - - std::vector flag_list = { - Flag("graph", &graph, "graph file name"), - // All the following flags are optional, but can be used in order - // to benchmark different input shapes. - Flag("input_layer", &input_layer_string, "input layer names"), - Flag("input_layer_shape", &input_layer_shape_string, "input layer shape"), - Flag("input_layer_type", &input_layer_type_string, "input layer type"), - Flag("input_layer_values", &input_layer_values_string, - "values to initialize the inputs with"), - Flag("output_layer", &output_layer_string, "output layer name"), - Flag("num_runs", &num_runs, "number of runs"), - Flag("run_delay", &run_delay, "delay between runs in seconds"), - Flag("num_threads", &num_threads, "number of threads"), - Flag("benchmark_name", &benchmark_name, "benchmark name"), - Flag("output_prefix", &output_prefix, "benchmark output prefix"), - Flag("warmup_runs", &warmup_runs, "how many runs to initialize model"), - Flag("use_nnapi", &use_nnapi, "use nnapi api"), - }; - string usage = Flags::Usage(argv[0], flag_list); - const bool parse_result = Flags::Parse(&argc, argv, flag_list); - tensorflow::port::InitMain(argv[0], &argc, &argv); - - if (!parse_result) { - std::cerr << usage << std::endl; - return -1; - } - - std::cout << "Graph: [" << graph << "]" << std::endl; - if (!input_layer_string.empty()) { - std::cout << "Input layers: [" << input_layer_string << "]" << std::endl; - std::cout << "Input shapes: [" << input_layer_shape_string << "]" - << std::endl; - std::cout << "Input types: [" << input_layer_type_string << "]" - << std::endl; - } - if (!output_layer_string.empty()) { - std::cout << "Output layers: [" << output_layer_string << "]" << std::endl; - } - std::cout << "Num runs: [" << num_runs << "]" << std::endl; - std::cout << "Inter-run delay (seconds): [" << run_delay << "]" << std::endl; - std::cout << "Num threads: [" << num_threads << "]" << std::endl; - if (!benchmark_name.empty()) { - std::cout << "Benchmark name: [" << benchmark_name << "]" << std::endl; - std::cout << "Output prefix: [" << output_prefix << "]" << std::endl; - } - std::cout << "Warmup runs: [" << warmup_runs << "]" << std::endl; - std::cout << "Use nnapi : [" << use_nnapi << "]" << std::endl; - - if (graph.empty()) { - std::cout - << "Please specify the name of your TF Lite input file with --graph" - << std::endl; - return -1; - } - - std::vector inputs; - if (!PopulateInputLayerInfo(input_layer_string, input_layer_shape_string, - input_layer_type_string, - input_layer_values_string, &inputs)) { - return -1; - } - - int64 initialization_start_us = Env::Default()->NowMicros(); - - std::unique_ptr model; - std::unique_ptr interpreter; - if (!CreateInterpreter(graph, &model, &interpreter)) { - return -1; - } - if (!PrepareInterpreter(inputs, num_threads, use_nnapi, interpreter.get())) { - return -1; - } - - int64 initialization_end_us = Env::Default()->NowMicros(); - - const double initialization_time_s = - (initialization_end_us - initialization_start_us) / 1000000.0f; - std::cout << "Initialized session in " << initialization_time_s << "s" - << std::endl; - - const double sleep_seconds = std::strtod(run_delay.c_str(), nullptr); - - // If requested, run through the graph first to preinitialize everything - // before the benchmarking runs. - int64 warmup_time_us = 0; - if (warmup_runs > 0) { - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, warmup_runs, - &warmup_time_us)) { - std::cerr << "Warmup failed" << std::endl; - return -1; - } - } - - // Capture overall inference time without stat logging overhead. This is the - // timing data that can be compared to other libaries. - int64 no_stat_time_us = 0; - if (!TimeMultipleRuns(interpreter.get(), sleep_seconds, num_runs, - &no_stat_time_us)) { - std::cerr << "Timing failed." << std::endl; - return -1; - } - - std::cout << "Average inference timings in us: " << no_stat_time_us / num_runs - << " , Warmup: " - << (warmup_runs > 0 ? warmup_time_us / warmup_runs : 0) << ", " - << std::endl; - - return 0; -} - -} // namespace tflite - -int main(int argc, char** argv) { return ::tflite::Main(argc, argv); } diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index ce8a7857d2dd66b12e9ea970911ef1dd01e4550e..ad7d59ecb41a0c81a6a4d8edae5fa6b4b5a7bede 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -41,7 +41,7 @@ class TfLiteFlatbufferModelBuilder { } TfLiteFlatbufferModelBuilder(const std::vector& builtin_ops, - const std::vector& custom_ops) { + const std::vector& custom_ops) { buffers_.push_back( CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); @@ -194,8 +194,8 @@ TEST(VerifyModel, TensorBufferIsNotValid) { /*operators=*/0, builder.CreateString("Main"))}); auto buffers = builder.CreateVector(std::vector>{ - CreateBuffer(builder, - builder.CreateVector(std::vector{1, 2, 3, 4, 5, 6})), + CreateBuffer(builder, builder.CreateVector( + std::vector{1, 2, 3, 4, 5, 6})), }); auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, /*operator_codes=*/0, diff --git a/tensorflow/contrib/lite/util.cc b/tensorflow/contrib/lite/util.cc index fb4af07d060cac3a6a4e01c7d625b6db5241f10d..8ccb65c24fd64f05d7e2c888f7932e586c1e11ec 100644 --- a/tensorflow/contrib/lite/util.cc +++ b/tensorflow/contrib/lite/util.cc @@ -38,4 +38,14 @@ bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, return true; } +size_t CombineHashes(std::initializer_list hashes) { + size_t result = 0; + // Hash combiner used by TensorFlow core. + for (size_t hash : hashes) { + result = result ^ + (hash + 0x9e3779b97f4a7800ULL + (result << 10) + (result >> 4)); + } + return result; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h index a34db35823104414cce028b9119397da085d05b1..89d9b4f5cffa99e708f391fd8fe19208009b5e79 100644 --- a/tensorflow/contrib/lite/util.h +++ b/tensorflow/contrib/lite/util.h @@ -35,6 +35,8 @@ TfLiteIntArray* ConvertArrayToTfLiteIntArray(const int rank, const int* dims); bool EqualArrayAndTfLiteIntArray(const TfLiteIntArray* a, const int b_size, const int* b); +size_t CombineHashes(std::initializer_list hashes); + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_UTIL_H_ diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py index 5d4682ec9f4b8c5864383bd1d2f4c0b41a11baad..5a080cceabb55c307dcd1a457a9e30d24e0bd172 100644 --- a/tensorflow/contrib/lookup/lookup_ops_test.py +++ b/tensorflow/contrib/lookup/lookup_ops_test.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib import lookup from tensorflow.python.client import session +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -1396,15 +1397,22 @@ class KeyValueTensorInitializerTest(test.TestCase): class IndexTableFromTensor(test.TestCase): + @test_util.run_in_graph_and_eager_modes() def test_index_table_from_tensor_with_tensor_init(self): - with self.test_session(): + table = lookup.index_table_from_tensor( + mapping=("brain", "salad", "surgery"), num_oov_buckets=1) + + if not context.executing_eagerly(): + with self.assertRaises(errors_impl.OpError): + self.evaluate(table.lookup( + constant_op.constant(("salad", "surgery", "tarkus")))) + else: + # Reinitializing a table in eager should work. table = lookup.index_table_from_tensor( mapping=("brain", "salad", "surgery"), num_oov_buckets=1) - ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) - - self.assertRaises(errors_impl.OpError, ids.eval) - lookup_ops.tables_initializer().run() - self.assertAllEqual((1, 2, 3), ids.eval()) + self.evaluate(lookup_ops.tables_initializer()) + ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus"))) + self.assertAllEqual((1, 2, 3), self.evaluate(ids)) def test_int32_index_table_from_tensor_with_tensor_init(self): with self.test_session(): diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index eff9081e35c285027c764c5bdbaf14f78bc5f512..48953e2e3843ff92744514d28bd725cc0d72f3a8 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -27,9 +27,7 @@ if [ ! -f $BZL_FILE_PATH ]; then fi EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" -# TODO (yongtang): Replace the following with 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' once -# the archive has been propagated in mirror.bazel.build. -GEMMLOWP_URL="$(grep -o 'https://github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 00a933e5e0c537033573b225d43581f74557b240..a6be2084aae6bb05f958929b45977ed21b570603 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -1544,7 +1544,7 @@ def precision_recall_at_equal_thresholds(labels, result: A named tuple (See PrecisionRecallData within the implementation of this function) with properties that are variables of shape `[num_thresholds]`. The names of the properties are tp, fp, tn, fn, - precision, recall, thresholds. + precision, recall, thresholds. Types are same as that of predictions. update_op: An op that accumulates values. Raises: @@ -1570,7 +1570,6 @@ def precision_recall_at_equal_thresholds(labels, check_ops.assert_type(labels, dtypes.bool) - dtype = predictions.dtype with variable_scope.variable_scope(name, 'precision_recall_at_equal_thresholds', (labels, predictions, weights)): @@ -1592,11 +1591,16 @@ def precision_recall_at_equal_thresholds(labels, predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - # We cast to float to ensure we have 0.0 or 1.0. - f_labels = math_ops.cast(labels, dtype) + # It's important we aggregate using float64 since we're accumulating a lot + # of 1.0's for the true/false labels, and accumulating to float32 will + # be quite inaccurate even with just a modest amount of values (~20M). + # We use float64 instead of integer primarily since GPU scatter kernel + # only support floats. + agg_dtype = dtypes.float64 - # Get weighted true/false labels. - true_labels = f_labels * weights + f_labels = math_ops.cast(labels, agg_dtype) + weights = math_ops.cast(weights, agg_dtype) + true_labels = f_labels * weights false_labels = (1.0 - f_labels) * weights # Flatten predictions and labels. @@ -1638,9 +1642,9 @@ def precision_recall_at_equal_thresholds(labels, with ops.name_scope('variables'): tp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='tp_buckets') + [num_thresholds], agg_dtype, name='tp_buckets') fp_buckets_v = metrics_impl.metric_variable( - [num_thresholds], dtype, name='fp_buckets') + [num_thresholds], agg_dtype, name='fp_buckets') with ops.name_scope('update_op'): update_tp = state_ops.scatter_add( @@ -1660,18 +1664,21 @@ def precision_recall_at_equal_thresholds(labels, fn = tp[0] - tp # We use a minimum to prevent division by 0. - epsilon = 1e-7 + epsilon = ops.convert_to_tensor(1e-7, dtype=agg_dtype) precision = tp / math_ops.maximum(epsilon, tp + fp) recall = tp / math_ops.maximum(epsilon, tp + fn) + # Convert all tensors back to predictions' dtype (as per function contract). + out_dtype = predictions.dtype + _convert = lambda tensor: math_ops.cast(tensor, out_dtype) result = PrecisionRecallData( - tp=tp, - fp=fp, - tn=tn, - fn=fn, - precision=precision, - recall=recall, - thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds)) + tp=_convert(tp), + fp=_convert(fp), + tn=_convert(tn), + fn=_convert(fn), + precision=_convert(precision), + recall=_convert(recall), + thresholds=_convert(math_ops.lin_space(0.0, 1.0, num_thresholds))) update_op = control_flow_ops.group(update_tp, update_fp) return result, update_op @@ -2496,7 +2503,7 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name): name: An optional variable_scope name. Returns: - The recall at a the given `precision`. + The recall at a given `precision`. """ precisions = math_ops.div(tp, tp + fp + _EPSILON) tf_index = math_ops.argmin( diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 76420db8bda39435bcc2be2fd3d8c3467d6753e2..b13f08a37d9e856d56903324fc6e7cf1457bb191 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -2333,47 +2333,24 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): np.random.seed(1) ops.reset_default_graph() - def _testResultsEqual(self, expected_dict, gotten_result): + def _testResultsEqual(self, expected_dict, gotten_result, eps=None): """Tests that 2 results (dicts) represent the same data. Args: expected_dict: A dictionary with keys that are the names of properties of PrecisionRecallData and whose values are lists of floats. gotten_result: A PrecisionRecallData object. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. """ gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} self.assertItemsEqual(list(expected_dict.keys()), list(gotten_dict.keys())) for key, expected_values in expected_dict.items(): - self.assertAllClose(expected_values, gotten_dict[key]) - - def _testCase(self, predictions, labels, expected_result, weights=None): - """Performs a test given a certain scenario of labels, predictions, weights. - - Args: - predictions: The predictions tensor. Of type float32. - labels: The labels tensor. Of type bool. - expected_result: The expected result (dict) that maps to tensors. - weights: Optional weights tensor. - """ - with self.test_session() as sess: - predictions_tensor = constant_op.constant( - predictions, dtype=dtypes_lib.float32) - labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) - weights_tensor = None - if weights: - weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) - gotten_result, update_op = ( - metric_ops.precision_recall_at_equal_thresholds( - labels=labels_tensor, - predictions=predictions_tensor, - weights=weights_tensor, - num_thresholds=3)) - - sess.run(variables.local_variables_initializer()) - sess.run(update_op) - - self._testResultsEqual(expected_result, gotten_result) + if eps is not None: + self.assertAllClose(expected_values, gotten_dict[key], atol=eps) + else: + self.assertAllClose(expected_values, gotten_dict[key]) def testVars(self): metric_ops.precision_recall_at_equal_thresholds( @@ -2414,6 +2391,78 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): for _ in range(3): self._testResultsEqual(initial_result, result) + def testLargeCase(self): + self.skipTest("Test consistently timing out") + shape = [32, 512, 256, 1] + predictions = random_ops.random_uniform( + shape, 0.0, 1.0, dtype=dtypes_lib.float32) + labels = math_ops.greater(random_ops.random_uniform(shape, 0.0, 1.0), 0.5) + + result, update_op = metric_ops.precision_recall_at_equal_thresholds( + labels=labels, predictions=predictions, num_thresholds=201) + # Run many updates, enough to cause highly inaccurate values if the + # code used float32 for accumulation. + num_updates = 71 + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_updates): + sess.run(update_op) + + prdata = sess.run(result) + + # Since we use random values, we won't know the tp/fp/tn/fn values, but + # tp and fp at threshold 0 should be the total number of positive and + # negative labels, hence their sum should be total number of pixels. + expected_value = 1.0 * np.product(shape) * num_updates + got_value = prdata.tp[0] + prdata.fp[0] + # They should be at least within 1. + self.assertNear(got_value, expected_value, 1.0) + + def _testCase(self, + predictions, + labels, + expected_result, + dtype=dtypes_lib.float32, + eps=None, + weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type dtype. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + dtype: Data type to use for predictions and weights tensor. Default + is float32. + eps: Epsilon value to use for testing output values. If unspecified, use + default from assertAllClose. + weights: Optional weights tensor. + """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant(predictions, dtype=dtype) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtype) + gotten_result, update_op = ( + metric_ops.precision_recall_at_equal_thresholds( + labels=labels_tensor, + predictions=predictions_tensor, + weights=weights_tensor, + num_thresholds=3)) + self.assertEqual(gotten_result.tp.dtype, dtype) + self.assertEqual(gotten_result.fp.dtype, dtype) + self.assertEqual(gotten_result.tn.dtype, dtype) + self.assertEqual(gotten_result.fn.dtype, dtype) + self.assertEqual(gotten_result.precision.dtype, dtype) + self.assertEqual(gotten_result.recall.dtype, dtype) + self.assertEqual(gotten_result.thresholds.dtype, dtype) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result, eps=eps) + def testAllTruePositives(self): self._testCase( [[1]], [[True]], { @@ -2489,6 +2538,35 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): }, weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) + def testFloat64(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float64) + + def testFloat16(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + dtype=dtypes_lib.float16, + eps=1e-3) + class StreamingSpecificityAtSensitivityTest(test.TestCase): @@ -7101,6 +7179,14 @@ class CohenKappaTest(test.TestCase): with self.assertRaises(ValueError): metrics.cohen_kappa(labels, invalid_predictions, 3) + def testConditionalPackingOptimization(self): + placeholder = array_ops.placeholder(dtypes_lib.float32, [None]) + values, update_op = metric_ops.streaming_concat(placeholder) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for feed in range(10): + sess.run(update_op, feed_dict={placeholder: [feed]}) + print(sess.run(values)) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py index e4e5ccc33472ad5a12bd8111fb1ff6ebbd6f45f9..ef34f7bf7bf3eba047b50ce8abf883b0ed741a63 100644 --- a/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py +++ b/tensorflow/contrib/mixed_precision/python/loss_scale_optimizer.py @@ -26,26 +26,32 @@ from tensorflow.python.training import optimizer class LossScaleOptimizer(optimizer.Optimizer): + # TODO(jamesqin): move mixed precision training explanation to __init__ + # docstring. """An optimizer that applies loss scaling in backprop. - This class is useful for mixed precision training on GPUs (or other potential - accelerators), which is an approach to improve compute throughput without loss - of model quality. - - The commmon configuration of mixed precision models is the following: - * variables are kept in high precision (e.g. float32). - * computations are done in lower precision (e.g. float16). variables are - casted to lower precision before they're used. - * (in training), final gradients are casted back to variable precision and get - applied. - - Because computations happen in lower precision, gradients in the backprop pass - might underflow in the smaller dynamic range, causing a model to converge at a - suboptimal level. This optimizer multiplies the loss by a factor before - backprop starts to prevent underflow. Before gradients are applied, they are - casted to higher precision and down-scaled by the same factor, so - mathematically the variable updates are no different from regular - same-precision training. + This class is useful for "mixed precision training" on GPUs (or other + potential accelerators), an approach to improve compute throughput without + compromising model quality. + + The canonical way to perform mixed precision training is the following: + * Model variables are kept in high precision (e.g. float32). + * Computations are done in lower precision (e.g. float16), which enjoys + performance speedup by virtue of hardware support. Variables are casted to + lower precision before they're used. + * Final gradients are casted back to high precision dtype, then used to update + variables. + + The side-effect of performing computation in lower precision, is that it comes + with smaller numerical range. During backproping, small gradients might + underflow in the reduced numerical range, causing a model to converge at + suboptimal level. + + To prevent underflow, this optimizer multiplies the loss by a factor before + backprop starts. Consequently, the gradients are linearly scaled up by the + same factor, thus not falling into the underflow zone. After that, to perserve + the correctness of backprop, the gradients are down-scaled by the same factor, + casted to the (higher) variable precision, then applied on the variables. See [Nvidia's manual on mixed precision training]( https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc index 4d8d922cb42d2974dab32cf4562bee3993bef098..5144f7c38c8650ebfced1dfcc9378263ebaad8c0 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc @@ -171,8 +171,7 @@ class NcclManagerTest : public ::testing::Test { private: static Allocator* GpuAllocator(BaseGPUDevice* device) { - return device->GetStepAllocator(AllocatorAttributes(), - nullptr /* step_resource_manager */); + return device->GetAllocator(AllocatorAttributes()); } static se::DeviceMemory AsDeviceMemory(const Scalar* cuda_memory) { diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index 13aa1d7e7a11877373a848c1ba865aa418790cd0..114b344d38413208755a47f36f45badc1a5ecaa9 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -28,6 +28,7 @@ py_library( "python/training/reg_adagrad_optimizer.py", "python/training/sign_decay.py", "python/training/variable_clipping_optimizer.py", + "python/training/weight_decay_optimizers.py", ], srcs_version = "PY2AND3", deps = [ @@ -194,6 +195,25 @@ py_test( ], ) +py_test( + name = "weight_decay_optimizers_test", + srcs = ["python/training/weight_decay_optimizers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":opt_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "drop_stale_gradient_optimizer_test", srcs = ["python/training/drop_stale_gradient_optimizer_test.py"], diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py index 4c13c8e247185213b798eb733ddcf65a07a8f64d..5df5d35f8e4f8fcc2c5aa09bd8f3254e16e3a74f 100644 --- a/tensorflow/contrib/opt/__init__.py +++ b/tensorflow/contrib/opt/__init__.py @@ -27,6 +27,7 @@ from tensorflow.contrib.opt.python.training.lazy_adam_optimizer import * from tensorflow.contrib.opt.python.training.moving_average_optimizer import * from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import * from tensorflow.contrib.opt.python.training.nadam_optimizer import * +from tensorflow.contrib.opt.python.training.weight_decay_optimizers import * from tensorflow.contrib.opt.python.training.powersign import * from tensorflow.contrib.opt.python.training.variable_clipping_optimizer import * from tensorflow.contrib.opt.python.training.elastic_average_optimizer import * @@ -46,6 +47,10 @@ _allowed_symbols = [ 'LazyAdamOptimizer', 'NadamOptimizer', 'MovingAverageOptimizer', + 'MomentumWOptimizer', + 'AdamWOptimizer', + 'DecoupledWeightDecayExtension', + 'extend_with_decoupled_weight_decay', 'ScipyOptimizerInterface', 'VariableClippingOptimizer', 'MultitaskOptimizerWrapper', diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py index bc92a7006f1a0a56adafc486a75afa94e965cb2c..915e6504e1e59ff247a2715820bc31a4d4cc1944 100644 --- a/tensorflow/contrib/opt/python/training/adamax_test.py +++ b/tensorflow/contrib/opt/python/training/adamax_test.py @@ -198,11 +198,11 @@ class AdaMaxOptimizerTest(test.TestCase): self.assertTrue(beta1_power is not None) self.assertIn(beta1_power, opt_variables) - with ops.Graph().as_default(): - # Shouldn't return non-slot variables from other graphs. - self.assertEqual(0, len(opt.variables())) - if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) # Fetch params to validate initial values self.assertAllClose([1.0, 2.0], self.evaluate(var0)) @@ -224,8 +224,10 @@ class AdaMaxOptimizerTest(test.TestCase): var1_np, m1, v1 = adamax_update_numpy(var1_np, grads1_np, t, m1, v1) # Validate updated params - self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) - self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0), + rtol=1e-2) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1), + rtol=1e-2) if use_resource: self.assertEqual("var0_%d/AdaMax:0" % (i,), opt.get_slot(var=var0, name="m").name) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa40aeb45d4ec15140bdfc5ebd824e8aa08d8d9 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py @@ -0,0 +1,326 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Base class to make optimizers weight decay ready.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.training import optimizer +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import adam +from tensorflow.python.training import momentum as momentum_opt +from tensorflow.python.util.tf_export import tf_export +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import resource_variable_ops + + +class DecoupledWeightDecayExtension(object): + """This class allows to extend optimizers with decoupled weight decay. + + It implements the decoupled weight decay described by Loshchilov & Hutter + (https://arxiv.org/pdf/1711.05101.pdf), in which the weight decay is + decoupled from the optimization steps w.r.t. to the loss function. + For SGD variants, this simplifies hyperparameter search since it decouples + the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + This class alone is not an optimizer but rather extends existing + optimizers with decoupled weight decay. We explicitly define the two examples + used in the above paper (SGDW and AdamW), but in general this can extend + any OptimizerX by using + `extend_with_weight_decay(OptimizerX, weight_decay=weight_decay)`. + In order for it to work, it must be the first class the Optimizer with + weight decay inherits from, e.g. + + ```python + class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + def __init__(self, weight_decay, *args, **kwargs): + super(AdamWOptimizer, self).__init__(weight_decay, *args, **kwargs). + ``` + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + """ + + def __init__(self, weight_decay, **kwargs): + """Construct the extension class that adds weight decay to an optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value, the factor by which + a variable is decayed in the update step. + decay_var_list: Optional list or tuple or set of `Variable` objects to + decay. + """ + self._decay_var_list = None # is set in minimize or apply_gradients + self._weight_decay = weight_decay + # The tensors are initialized in call to _prepare + self._weight_decay_tensor = None + super(DecoupledWeightDecayExtension, self).__init__(**kwargs) + + def minimize(self, loss, global_step=None, var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, colocate_gradients_with_ops=False, + name=None, grad_loss=None, decay_var_list=None): + """Add operations to minimize `loss` by updating `var_list` with decay. + + This function is the same as Optimizer.minimize except that it allows to + specify the variables that should be decayed using decay_var_list. + If decay_var_list is None, all variables in var_list are decayed. + + For more information see the documentation of Optimizer.minimize. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).minimize( + loss, global_step=global_step, var_list=var_list, + gate_gradients=gate_gradients, aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, name=name, + grad_loss=grad_loss) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None, + decay_var_list=None): + """Apply gradients to variables and decay the variables. + + This function is the same as Optimizer.apply_gradients except that it + allows to specify the variables that should be decayed using + decay_var_list. If decay_var_list is None, all variables in var_list + are decayed. + + For more information see the documentation of Optimizer.apply_gradients. + """ + self._decay_var_list = set(decay_var_list) if decay_var_list else False + return super(DecoupledWeightDecayExtension, self).apply_gradients( + grads_and_vars, global_step=global_step, name=name) + + def _prepare(self): + weight_decay = self._weight_decay + if callable(weight_decay): + weight_decay = weight_decay() + self._weight_decay_tensor = ops.convert_to_tensor( + weight_decay, name="weight_decay") + # Call the optimizers _prepare function. + super(DecoupledWeightDecayExtension, self)._prepare() + + def _decay_weights_op(self, var): + if not self._decay_var_list or var in self._decay_var_list: + return var.assign_sub(self._weight_decay * var, self._use_locking) + return control_flow_ops.no_op() + + def _decay_weights_sparse_op(self, var, indices, scatter_add): + if not self._decay_var_list or var in self._decay_var_list: + return scatter_add(var, indices, -self._weight_decay * var, + self._use_locking) + return control_flow_ops.no_op() + + # Here, we overwrite the apply functions that the base optimizer calls. + # super().apply_x resolves to the apply_x function of the BaseOptimizer. + def _apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._apply_dense(grad, var) + + def _resource_apply_dense(self, grad, var): + with ops.control_dependencies([self._decay_weights_op(var)]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_dense( + grad, var) + + def _apply_sparse(self, grad, var): + scatter_add = state_ops.scatter_add + decay_op = self._decay_weights_sparse_op(var, grad.indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._apply_sparse( + grad, var) + + def _resource_scatter_add(self, x, i, v, _=None): + # last argument allows for one overflow argument, to have the same function + # signature as state_ops.scatter_add + with ops.control_dependencies( + [resource_variable_ops.resource_scatter_add(x.handle, i, v)]): + return x.value() + + def _resource_apply_sparse(self, grad, var, indices): + scatter_add = self._resource_scatter_add + decay_op = self._decay_weights_sparse_op(var, indices, scatter_add) + with ops.control_dependencies([decay_op]): + return super(DecoupledWeightDecayExtension, self)._resource_apply_sparse( + grad, var, indices) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """Factory function returning an optimizer class with decoupled weight decay. + + Returns an optimizer class. An instance of the returned class computes the + update step of `base_optimizer` and additionally decays the weights. + E.g., the class returned by + `extend_with_decoupled_weight_decay(tf.train.AdamOptimizer)` is equivalent to + `tf.contrib.opt.AdamWOptimizer`. + + The API of the new optimizer class slightly differs from the API of the + base optimizer: + - The first argument to the constructor is the weight decay rate. + - `minimize` and `apply_gradients` accept the optional keyword argument + `decay_var_list`, which specifies the variables that should be decayed. + If `None`, all variables that are optimized are decayed. + + Usage example: + ```python + # MyAdamW is a new class + MyAdamW = extend_with_decoupled_weight_decay(tf.train.AdamOptimizer) + # Create a MyAdamW object + optimizer = MyAdamW(weight_decay=0.001, learning_rate=0.001) + sess.run(optimizer.minimize(loss, decay_variables=[var1, var2])) + + Note that this extension decays weights BEFORE applying the update based + on the gradient, i.e. this extension only has the desired behaviour for + optimizers which do not depend on the value of'var' in the update step! + ``` + + Args: + base_optimizer: An optimizer class that inherits from tf.train.Optimizer. + + Returns: + A new optimizer class that inherits from DecoupledWeightDecayExtension + and base_optimizer. + """ + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecayExtension, + base_optimizer): + """Base_optimizer with decoupled weight decay. + + This class computes the update step of `base_optimizer` and + additionally decays the variable with the weight decay being decoupled from + the optimization steps w.r.t. to the loss function, as described by + Loshchilov & Hutter (https://arxiv.org/pdf/1711.05101.pdf). + For SGD variants, this simplifies hyperparameter search since + it decouples the settings of weight decay and learning rate. + For adaptive gradient algorithms, it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield + better training loss and generalization error in the paper above. + """ + + def __init__(self, weight_decay, *args, **kwargs): + # super delegation is necessary here + # pylint: disable=useless-super-delegation + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, *args, **kwargs) + # pylint: enable=useless-super-delegation + + return OptimizerWithDecoupledWeightDecay + + +@tf_export("contrib.opt.MomentumWOptimizer") +class MomentumWOptimizer(DecoupledWeightDecayExtension, + momentum_opt.MomentumOptimizer): + """Optimizer that implements the Momentum algorithm with weight_decay. + + This is an implementation of the SGDW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + It computes the update step of `train.MomentumOptimizer` and additionally + decays the variable. Note that this is different from adding + L2 regularization on the variables to the loss. Decoupling the weight decay + from other hyperparameters (in particular the learning rate) simplifies + hyperparameter search. + + For further information see the documentation of the Momentum Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.MomentumOptimizer, + weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate, momentum, + use_locking=False, name="MomentumW", use_nesterov=False): + """Construct a new MomentumW optimizer. + + For further information see the documentation of the Momentum Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A `Tensor` or a floating point value. The learning rate. + momentum: A `Tensor` or a floating point value. The momentum. + use_locking: If `True` use locks for update operations. + name: Optional name prefix for the operations created when applying + gradients. Defaults to "Momentum". + use_nesterov: If `True` use Nesterov Momentum. + See [Sutskever et al., 2013]( + http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). + This implementation always computes gradients at the value of the + variable(s) passed to the optimizer. Using Nesterov Momentum makes the + variable(s) track the values called `theta_t + mu*v_t` in the paper. + + @compatibility(eager) + When eager execution is enabled, learning_rate, weight_decay and momentum + can each be a callable that takes no arguments and returns the actual value + to use. This can be useful for changing these values across different + invocations of optimizer functions. + @end_compatibility + """ + super(MomentumWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, momentum=momentum, + use_locking=use_locking, name=name, use_nesterov=use_nesterov) + + +@tf_export("contrib.opt.AdamWOptimizer") +class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer): + """Optimizer that implements the Adam algorithm with weight decay. + + This is an implementation of the AdamW optimizer described in "Fixing + Weight Decay Regularization in Adam" by Loshchilov & Hutter + (https://arxiv.org/abs/1711.05101) + ([pdf])(https://arxiv.org/pdf/1711.05101.pdf). + + It computes the update step of `train.AdamOptimizer` and additionally decays + the variable. Note that this is different from adding L2 regularization on + the variables to the loss: it regularizes variables with large + gradients more than L2 regularization would, which was shown to yield better + training loss and generalization error in the paper above. + + For further information see the documentation of the Adam Optimizer. + + Note that this optimizer can also be instantiated as + ```python + extend_with_weight_decay(tf.train.AdamOptimizer, weight_decay=weight_decay) + ``` + """ + + def __init__(self, weight_decay, learning_rate=0.001, beta1=0.9, beta2=0.999, + epsilon=1e-8, use_locking=False, name="AdamW"): + """Construct a new AdamW optimizer. + + For further information see the documentation of the Adam Optimizer. + + Args: + weight_decay: A `Tensor` or a floating point value. The weight decay. + learning_rate: A Tensor or a floating point value. The learning rate. + beta1: A float value or a constant float tensor. + The exponential decay rate for the 1st moment estimates. + beta2: A float value or a constant float tensor. + The exponential decay rate for the 2nd moment estimates. + epsilon: A small constant for numerical stability. This epsilon is + "epsilon hat" in the Kingma and Ba paper (in the formula just before + Section 2.1), not the epsilon in Algorithm 1 of the paper. + use_locking: If True use locks for update operations. + name: Optional name for the operations created when applying gradients. + Defaults to "Adam". + """ + super(AdamWOptimizer, self).__init__( + weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2, + epsilon=epsilon, use_locking=use_locking, name=name) diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..74d1cdbbdac8724518937d141a976abf9fec6ce3 --- /dev/null +++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers_test.py @@ -0,0 +1,190 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for optimizers with weight decay.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import adam +from tensorflow.contrib.opt.python.training import weight_decay_optimizers + +WEIGHT_DECAY = 0.01 + + +def adamw_update_numpy(param, g_t, t, m, v, lr=0.001, beta1=0.9, + beta2=0.999, epsilon=1e-8): + lr_t = lr * np.sqrt(1 - beta2**t) / (1 - beta1**t) + + m_t = beta1 * m + (1 - beta1) * g_t + v_t = beta2 * v + (1 - beta2) * g_t * g_t + + param_t = (param - lr_t * m_t / (np.sqrt(v_t) + epsilon) - + (param * WEIGHT_DECAY)) + return param_t, m_t, v_t + + +def momentumw_update_numpy(param, g_t, m, lr=0.001, momentum=0.9, **_): + # v, t are not needed for momentum optimizer + m = momentum * m + g_t + param_t = param - lr * m - param * WEIGHT_DECAY + return param_t, m, None + + +class WeightDecayOptimizerTest(test.TestCase): + + def doTest(self, optimizer, update_fn, optimizer_name, slot_name, + use_resource=False, do_sparse=False): + for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]): + with self.test_session(graph=ops.Graph()): + # Initialize variables for numpy implementation. + m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0 + var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype) + grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype) + var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype) + grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype) + + if use_resource: + var0 = resource_variable_ops.ResourceVariable( + var0_np, name="var0_%d" % i) + var1 = resource_variable_ops.ResourceVariable( + var1_np, name="var1_%d" % i) + else: + var0 = variables.Variable(var0_np) + var1 = variables.Variable(var1_np) + + if do_sparse: + grads0_np_indices = np.array([0, 1], dtype=np.int32) + grads0 = ops.IndexedSlices(constant_op.constant(grads0_np), + constant_op.constant(grads0_np_indices), + constant_op.constant([2])) + grads1_np_indices = np.array([0, 1], dtype=np.int32) + grads1 = ops.IndexedSlices(constant_op.constant(grads1_np), + constant_op.constant(grads1_np_indices), + constant_op.constant([2])) + else: + grads0 = constant_op.constant(grads0_np) + grads1 = constant_op.constant(grads1_np) + + opt = optimizer() + update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + + if not context.executing_eagerly(): + with ops.Graph().as_default(): + # Shouldn't return non-slot variables from other graphs. + self.assertEqual(0, len(opt.variables())) + self.evaluate(variables.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + self.assertAllClose([3.0, 4.0], self.evaluate(var1)) + + # Run 3 steps of the optimizer + for t in range(1, 4): + if not context.executing_eagerly(): + self.evaluate(update) + elif t > 1: + opt.apply_gradients(zip([grads0, grads1], [var0, var1])) + + var0_np, m0, v0 = update_fn(var0_np, grads0_np, t=t, m=m0, v=v0) + var1_np, m1, v1 = update_fn(var1_np, grads1_np, t=t, m=m1, v=v1) + + # Validate updated params + self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0)) + self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1)) + if use_resource: + self.assertEqual("var0_%d/%s:0" % (i, optimizer_name), + opt.get_slot(var=var0, name=slot_name).name) + + +class AdamWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.AdamWOptimizer(WEIGHT_DECAY) + + def testSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "AdamW", "m", + use_resource=True) + + +class MomentumWOptimizerTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + return weight_decay_optimizers.MomentumWOptimizer(WEIGHT_DECAY, 0.001, 0.9) + + def testSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False, do_sparse=True) + + def testResourceSparse(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True, do_sparse=True) + + def testBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, momentumw_update_numpy, "MomentumW", + "momentum", use_resource=True) + + +class ExtendWithWeightDecayTest(WeightDecayOptimizerTest): + + @staticmethod + def get_optimizer(): + AdamW = weight_decay_optimizers.extend_with_decoupled_weight_decay( + adam.AdamOptimizer) + return AdamW(WEIGHT_DECAY) + + def testBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=False) + + @test_util.run_in_graph_and_eager_modes(reset_test=True) + def testResourceBasic(self): + self.doTest(self.get_optimizer, adamw_update_numpy, "Adam", "m", + use_resource=True) + + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/optimizer_v2/momentum_test.py b/tensorflow/contrib/optimizer_v2/momentum_test.py index 26724f66c2a1db1d01577b31b739af18f51d3976..24cdab462665adc6297b0e0821455a545c3880af 100644 --- a/tensorflow/contrib/optimizer_v2/momentum_test.py +++ b/tensorflow/contrib/optimizer_v2/momentum_test.py @@ -134,7 +134,6 @@ class MomentumOptimizerTest(test.TestCase): with context.eager_mode(): self.doTestBasic(use_resource=True, use_callable_params=True) - @test_util.run_in_graph_and_eager_modes(reset_test=True) def testVariablesAcrossGraphs(self): optimizer = momentum_lib.MomentumOptimizer(0.01, 0.5) with ops.Graph().as_default(): @@ -142,10 +141,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var0") var1 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var1") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var0 + var1) - else: - loss = math_ops.reduce_sum(var0 + var1) + loss = math_ops.reduce_sum(var0 + var1) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var0") @@ -157,10 +153,7 @@ class MomentumOptimizerTest(test.TestCase): [1.0, 2.0], dtype=dtypes.float32, name="var2") var3 = resource_variable_ops.ResourceVariable( [3.0, 4.0], dtype=dtypes.float32, name="var3") - if context.executing_eagerly(): - loss = lambda: math_ops.reduce_sum(var2 + var3) - else: - loss = math_ops.reduce_sum(var2 + var3) + loss = math_ops.reduce_sum(var2 + var3) optimizer.minimize(loss) optimizer_variables = optimizer.variables() self.assertStartsWith(optimizer_variables[0].name, "var2") diff --git a/tensorflow/contrib/periodic_resample/BUILD b/tensorflow/contrib/periodic_resample/BUILD index 6ca7fe8b6e59b0dc24be76262d4f54f387e53e48..976b312e8345a801ad07f622b6117b88af2cf603 100644 --- a/tensorflow/contrib/periodic_resample/BUILD +++ b/tensorflow/contrib/periodic_resample/BUILD @@ -6,12 +6,13 @@ exports_files(["LICENSE"]) load( "//tensorflow:tensorflow.bzl", - "py_test", + "tf_cc_test", "tf_gen_op_libs", "tf_custom_op_library", "tf_custom_op_py_library", "tf_gen_op_wrapper_py", ) +load("//tensorflow:tensorflow.bzl", "py_test") cc_library( name = "all_ops", @@ -84,6 +85,20 @@ py_test( ":init_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradient_checker", + ], +) + +tf_cc_test( + name = "periodic_resample_op_cc_test", + size = "small", + srcs = [ + "ops/array_ops_test.cc", + ], + deps = [ + ":all_ops", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc index e18923c8aae74c66ce78f98eb5e615e99463af74..514689cf4543cd08632bd0321a78fa933c456467 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc @@ -22,4 +22,9 @@ namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("PeriodicResample").Device(DEVICE_CPU), PeriodicResampleOp); + +REGISTER_KERNEL_BUILDER(Name("PeriodicResampleOpGrad") + .Device(DEVICE_CPU), + PeriodicResampleOpGrad); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h index 3ab588c45881c8f93b4c1bcdf7ccde39086a1ed7..42fba81a5cb9490c093062048f269704a110756a 100644 --- a/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h +++ b/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h @@ -25,92 +25,202 @@ #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/work_sharder.h" namespace { -template -IndexT compute_input_index( - IndexVecT* target_dimensions, const IndexT& output_index, - const IndexVecT& original_dimensions, const int& adjustable_dimension, - const std::vector& dimension_ceiling, - const std::vector& cumulative_dimensions, IndexT* result, - std::vector* output_indices, const int& rank) { - *result = 0; - output_indices->clear(); +// Computes input tensor index for given output index during forward +// propagation through periodic_resample operation. +class InputIndexer { + public: + InputIndexer(const std::vector& output_dimensions, + const tensorflow::TensorShape& input_shape, + int adjustable_dimension) + : output_dimensions_(output_dimensions), + adjustable_dimension_(adjustable_dimension), + rank_(input_shape.dims()), + linear_output_index_(0), + linear_input_index_(0), + adjustable_dimension_carriage_sum_(0) { + auto input_dimensions = TensorShapeToVector(input_shape); + // factors by which input_dimensions increases/decreases w.r.t. + // output_dimensions + dimension_ceiling_ = + ComputeDimensionCeiling(output_dimensions, input_dimensions); + cumulative_dimensions_ = ComputeCumulativeDimensions(); + + output_indices_.resize(output_dimensions_.size()); + input_indices_.resize(output_dimensions_.size()); + + // Compute index_factors + index_factors_.resize(rank_); + tensorflow::int64 last_index_factor = 1; + for (auto r = rank_ - 1; r >= 0; --r) { + index_factors_[r] = last_index_factor; + last_index_factor *= input_dimensions[r]; + } + } + + tensorflow::int64 linear_input_index() const { return linear_input_index_; } + + void MoveToOutputIndex(tensorflow::int64 output_index); + void IncrementOutputIndex(); + + private: + void RecomputeInputAdjustableDimensionIndex() { + tensorflow::int64 index = adjustable_dimension_carriage_sum_; + index *= output_dimensions_[adjustable_dimension_]; + index += output_indices_[adjustable_dimension_]; + input_indices_[adjustable_dimension_] = index; + } + + std::vector TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape); + + std::vector ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions); + + std::vector ComputeCumulativeDimensions(); + + const std::vector output_dimensions_; + std::vector dimension_ceiling_; + std::vector index_factors_; + std::vector cumulative_dimensions_; + std::vector output_indices_; + std::vector input_indices_; + + const int adjustable_dimension_; + const int rank_; + tensorflow::int64 linear_output_index_; + tensorflow::int64 linear_input_index_; + tensorflow::int64 adjustable_dimension_carriage_sum_; +}; + +void InputIndexer::MoveToOutputIndex(tensorflow::int64 output_index) { + linear_output_index_ = output_index; + linear_input_index_ = 0; // un-rasterize the output index auto last_reduced_i = output_index; - for (auto r = rank - 1; r >= 0; --r) { - (*output_indices)[r] = last_reduced_i % (*target_dimensions)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + output_indices_[r] = last_reduced_i % output_dimensions_[r]; last_reduced_i = - (last_reduced_i - (*output_indices)[r]) / (*target_dimensions)[r]; + (last_reduced_i - output_indices_[r]) / output_dimensions_[r]; } + tensorflow::int64 carriage_sum = 0; + for (int qi = 0; qi < rank_; ++qi) { + if (qi == adjustable_dimension_) continue; + carriage_sum += cumulative_dimensions_[qi] * + (output_indices_[qi] % dimension_ceiling_[qi]); + } + adjustable_dimension_carriage_sum_ = carriage_sum; + // rasterize the input index - IndexT last_index_factor = 1; - for (auto r = rank - 1; r >= 0; --r) { - IndexT index = 0; - if (r != adjustable_dimension) - index = (*output_indices)[r] / dimension_ceiling[r]; - else { - for (int qi = 0; qi < rank; ++qi) { - if (qi == adjustable_dimension) continue; - index += cumulative_dimensions[qi] * - ((*output_indices)[qi] % dimension_ceiling[qi]); - } - index *= (*target_dimensions)[adjustable_dimension]; - index += (*output_indices)[r]; + for (auto r = rank_ - 1; r >= 0; --r) { + if (r != adjustable_dimension_) { + input_indices_[r] = output_indices_[r] / dimension_ceiling_[r]; + } else { + RecomputeInputAdjustableDimensionIndex(); } - *result += last_index_factor * index; - last_index_factor *= original_dimensions[r]; } + for (auto r = rank_ - 1; r >= 0; --r) { + linear_input_index_ += index_factors_[r] * input_indices_[r]; + } +} + +void InputIndexer::IncrementOutputIndex() { + linear_output_index_++; + for (auto r = rank_ - 1; r >= 0; --r) { + auto old_carriage_sum_increment = + cumulative_dimensions_[r] * + (output_indices_[r] % dimension_ceiling_[r]); + output_indices_[r] = (output_indices_[r] + 1) % output_dimensions_[r]; + if (r != adjustable_dimension_) { + auto new_input_index = output_indices_[r] / dimension_ceiling_[r]; + linear_input_index_ += + (new_input_index - input_indices_[r]) * index_factors_[r]; + + input_indices_[r] = new_input_index; + + auto new_carriage_sum_increment = + cumulative_dimensions_[r] * + (output_indices_[r] % dimension_ceiling_[r]); - return *result; + adjustable_dimension_carriage_sum_ = adjustable_dimension_carriage_sum_ - + old_carriage_sum_increment + + new_carriage_sum_increment; + } + + if (output_indices_[r] != 0) { + // No more carries to higher indices. + break; + } + } + auto old_adjustable_dimension_input_index = + input_indices_[adjustable_dimension_]; + RecomputeInputAdjustableDimensionIndex(); + linear_input_index_ += (input_indices_[adjustable_dimension_] - + old_adjustable_dimension_input_index) * + index_factors_[adjustable_dimension_]; } -template // both types are needed here b/c IndexVecT and - // InputDataT are not related - void - fill_periodic_tensor( - tensorflow::OpKernelContext* context, - const IndexVecT& desired_shape, - const tensorflow::Tensor& input_tensor) { - // input is a strided array (last index is fastest, C-ordered) - auto input = input_tensor.flat(); - const int rank = input_tensor.dims(); - // original and target dimensions - std::vector original_dimensions(rank), - target_dimensions(rank); - tensorflow::int64 total_size(input_tensor.NumElements()), new_sliced_size(1); - // factors by which original_dimensions increases/decreases w.r.t. - // target_dimensions - std::vector dimension_ceiling(rank), - cumulative_dimensions(rank); - // index of adjustable dimension - int adjustable_dimension; - tensorflow::TensorShape output_shape; +std::vector InputIndexer::TensorShapeToVector( + const tensorflow::TensorShape& tensor_shape) { + std::vector result(tensor_shape.dims()); + int count = 0; + for (const auto dim_info : tensor_shape) { + result[count] = dim_info.size; + ++count; + } + return result; +} - // requires that the rank of the input tensor and length of the desired shape - // are equal - OP_REQUIRES(context, rank == desired_shape.size(), - tensorflow::errors::InvalidArgument( - "periodic_resample expects the rank of the input tensor, ", - rank, ", to be the same as the length of the desired shape, ", - desired_shape.size(), ".")); +std::vector InputIndexer::ComputeDimensionCeiling( + const std::vector& output_dimensions, + const std::vector& input_dimensions) { + std::vector dimension_ceiling(input_dimensions.size()); + for (size_t i = 0; i < input_dimensions.size(); ++i) { + dimension_ceiling[i] = (output_dimensions[i] + input_dimensions[i] - 1) / + input_dimensions[i]; + } + return dimension_ceiling; +} - bool found = false; - const auto& input_tensor_shape = input_tensor.shape(); +std::vector InputIndexer::ComputeCumulativeDimensions() { + std::vector cumulative_dimensions(rank_); + int count = 0; + for (int i = 0; i < rank_; ++i) { + if (count == 0) { + cumulative_dimensions[count] = 1; + } else { + cumulative_dimensions[count] = + cumulative_dimensions[count - 1] * dimension_ceiling_[count - 1]; + } + ++count; + } + return cumulative_dimensions; +} +template +void process_desired_shape(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& input_tensor_shape, + const IndexVecT& desired_shape, + int* adjustable_dimension, + std::vector* target_dimensions, + tensorflow::int64* output_size) { + tensorflow::int64 new_sliced_size = 1; + bool found = false; + const int rank = input_tensor_shape.dims(); for (int i = 0; i < rank; ++i) { - // if (desired_shape(i) < 1) { if (desired_shape[i] < 1) { // only one index can be adjustable OP_REQUIRES(context, !found, tensorflow::errors::InvalidArgument( "periodic_resample expects only " "one index to be marked as adjustable.")); - adjustable_dimension = i; + *adjustable_dimension = i; found = true; } else { OP_REQUIRES( @@ -122,9 +232,8 @@ template +void +do_periodic_resample_op(tensorflow::OpKernelContext* context, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape, + const tensorflow::Tensor& source_tensor) { + const int rank = source_tensor.dims(); + + // requires that the rank of the input tensor and length of the desired shape + // are equal + OP_REQUIRES(context, rank == desired_shape.dims(), + tensorflow::errors::InvalidArgument( + "periodic_resample expects the rank of the input tensor, ", + rank, ", to be the same as the length of the desired shape, ", + desired_shape.dims(), ".")); + + std::vector target_dimensions(rank); + tensorflow::int64 new_size = 0; + // index of adjustable dimension + int adjustable_dimension = 0; + process_desired_shape(context, original_shape, desired_shape.dim_sizes(), + &adjustable_dimension, &target_dimensions, &new_size); // ensure that the new dimension is greater than zero OP_REQUIRES(context, target_dimensions[adjustable_dimension] > 0, @@ -160,11 +293,14 @@ template allocate_output(0, output_shape, &output_tensor)); auto output = output_tensor->flat(); - // memory is allocated for these variables outside the inner loop for - // efficiency (although, I could create a separate class scope for - // this purpose instead) - tensorflow::int64 result = 0; - std::vector output_indices(target_dimensions.size()); + // input is a strided array (last index is fastest, C-ordered) + auto input = source_tensor.flat(); // Fill output tensor with periodically resampled input tensor values - for (tensorflow::int64 output_index = 0; output_index < new_size; - ++output_index) { - output(output_index) = input(compute_input_index( - &target_dimensions, output_index, original_dimensions, - adjustable_dimension, dimension_ceiling, cumulative_dimensions, &result, - &output_indices, rank)); - } + InputIndexer input_indexer(target_dimensions, original_shape, + adjustable_dimension); + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + auto fill_output_tensor = [&input_indexer, &output, &input]( + tensorflow::int64 start, tensorflow::int64 limit) { + InputIndexer local_indexer(input_indexer); + local_indexer.MoveToOutputIndex(start); + for (tensorflow::int64 output_index = start; output_index < limit; + ++output_index) { + if (mode == Mode::kForward) { + output(output_index) = input(local_indexer.linear_input_index()); + } else { + output(local_indexer.linear_input_index()) = input(output_index); + } + local_indexer.IncrementOutputIndex(); + } + }; + ::tensorflow::Shard(worker_threads.num_threads, worker_threads.workers, + new_size, costPerFillIndex, fill_output_tensor); } +#define DATA_TYPE_SWITCH(data_type, context, CASE) \ + switch (data_type) { \ + CASE(float) \ + CASE(double) \ + CASE(tensorflow::int32) \ + CASE(tensorflow::int64) \ + default: \ + context->CtxFailure(__FILE__, __LINE__, \ + tensorflow::errors::InvalidArgument( \ + "Unsuppored tensor elements type")); \ + break; \ + } + void create_output_tensor( tensorflow::OpKernelContext* context, const tensorflow::Tensor& input_tensor, const tensorflow::DataType& input_tensor_type, - const tensorflow::PartialTensorShape& desired_shape_tensor) { - auto desired_shape = desired_shape_tensor.dim_sizes(); - - // obligatory type switch - switch (input_tensor_type) { - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + do_periodic_resample_op( \ + context, input_tensor.shape(), desired_shape, input_tensor); \ break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); - break; - case tensorflow::DataTypeToEnum::value: - fill_periodic_tensor(context, desired_shape, - input_tensor); + + DATA_TYPE_SWITCH(input_tensor_type, context, CASE); +#undef CASE +} + +void create_grad_tensor(tensorflow::OpKernelContext* context, + const tensorflow::Tensor& grad_tensor, + const tensorflow::DataType& grad_tensor_type, + const tensorflow::TensorShape& original_shape, + const tensorflow::PartialTensorShape& desired_shape) { +#define CASE(type) \ + case tensorflow::DataTypeToEnum::value: \ + do_periodic_resample_op( \ + context, original_shape, desired_shape, grad_tensor); \ break; - default:; - } + + DATA_TYPE_SWITCH(grad_tensor_type, context, CASE); +#undef CASE } } // namespace @@ -238,4 +400,25 @@ class PeriodicResampleOp : public tensorflow::OpKernel { tensorflow::PartialTensorShape desired_shape; }; +class PeriodicResampleOpGrad : public tensorflow::OpKernel { + public: + explicit PeriodicResampleOpGrad(tensorflow::OpKernelConstruction* context) + : tensorflow::OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("original_shape", &original_shape)); + OP_REQUIRES_OK(context, context->GetAttr("desired_shape", &desired_shape)); + } + + void Compute(tensorflow::OpKernelContext* context) override { + const tensorflow::Tensor& grad_tensor = context->input(0); + const tensorflow::DataType grad_tensor_type = context->input_dtype(0); + create_grad_tensor(context, grad_tensor, grad_tensor_type, original_shape, + desired_shape); + } + + private: + tensorflow::TensorShape original_shape; + tensorflow::PartialTensorShape desired_shape; +}; + #endif // TENSORFLOW_KERNELS_PERIODICRESAMPLE_OP_H_ diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops.cc b/tensorflow/contrib/periodic_resample/ops/array_ops.cc index 82bd79695646e3673c2c78ad99dd2bd200fc2fbf..fd38cd09b4d0939d7955f7839763a8e955b71fa5 100644 --- a/tensorflow/contrib/periodic_resample/ops/array_ops.cc +++ b/tensorflow/contrib/periodic_resample/ops/array_ops.cc @@ -26,7 +26,42 @@ REGISTER_OP("PeriodicResample") .Input("values: T") .Attr("shape: shape") .Output("output: T") - .SetShapeFn(shape_inference::ExplicitShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::PartialTensorShape desired_shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape)); + shape_inference::ShapeHandle input_tensor_shape = c->input(0); + shape_inference::DimensionHandle num_input_elements = + c->NumElements(input_tensor_shape); + shape_inference::ShapeHandle result_shape_handle; + if (!shape_inference::InferenceContext::ValueKnown(num_input_elements)) { + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + desired_shape, &result_shape_handle)); + } else { + const int rank = c->Rank(input_tensor_shape); + std::vector target_dimensions(rank); + tensorflow::int64 new_sliced_size = 1; + int adjustable_dimension = 0; + for (int i = 0; i < rank; ++i) { + if (desired_shape.dim_size(i) < 1) { + adjustable_dimension = i; + } else { + target_dimensions[i] = desired_shape.dim_size(i); + new_sliced_size *= target_dimensions[i]; + } + } + target_dimensions[adjustable_dimension] = + shape_inference::InferenceContext::Value( + num_input_elements) / new_sliced_size; + tensorflow::TensorShape result_shape; + for (int i = 0; i < rank; ++i) { + result_shape.AddDim(target_dimensions[i]); + } + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape( + result_shape, &result_shape_handle)); + } + c->set_output(0, result_shape_handle); + return Status::OK(); + }) .Doc(R"doc( Periodically resample elements of a tensor to conform to `shape`. @@ -101,4 +136,20 @@ output: Periodically resampled tensor that has dimensions specified as in )doc"); + +REGISTER_OP("PeriodicResampleOpGrad") + .Attr("T: numbertype") + .Input("grad: T") + .Attr("original_shape: shape") + .Attr("desired_shape: shape") + .Output("grad_values: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + tensorflow::TensorShape original_shape; + TF_RETURN_IF_ERROR(c->GetAttr("original_shape", &original_shape)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(original_shape, &s)); + c->set_output(0, s); + return Status::OK(); +}); + } // namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..55edf76fcd3eed461e1465b569e1c2e9e2facbc0 --- /dev/null +++ b/tensorflow/contrib/periodic_resample/ops/array_ops_test.cc @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/shape_inference_testutil.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(ArrayOpsTest, PeriodicResample_ShapeFn) { + ShapeInferenceTestOp op("PeriodicResample"); + // Case 1: output shape can be fully inferreed. + PartialTensorShape shape({4, 4, -1}); + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + + TF_ASSERT_OK(NodeDefBuilder("test", "PeriodicResample") + .Input({"values", 0, DT_INT32}) + .Attr("shape", shape_proto) + .Finalize(&op.node_def)); + INFER_OK(op, "[2,2,4]", "[4,4,1]"); + // Case 2: output shape can not be inferred - report desired shape. + INFER_OK(op, "[2,2,?]", "[4,4,?]"); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py index a25de55e18b223db2b724aafb54b18d8f48a5baa..31a6fe1d94b8a972087e00cf7c676105b0f1129b 100644 --- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py +++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py @@ -21,8 +21,11 @@ from __future__ import print_function import numpy from tensorflow.contrib.periodic_resample import periodic_resample +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -93,7 +96,6 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): def testPeriodicResampleErrors(self): input_tensor = numpy.zeros(shape=[1, 2, 2, 4]) with self.test_session(): - variables.global_variables_initializer().run() with self.assertRaisesWithPredicateMatch( errors_impl.InvalidArgumentError, 'Dimension 3 input tensor has size 4, desired shape has size 1'): @@ -103,6 +105,29 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase): '4, to be the same as the length of the desired shape, 3'): periodic_resample(input_tensor, [None, 4, 4]).eval() + def testPeriodicResampleGradient(self): + desired_shape = numpy.array([4, 4, None]) + result_shape = (4, 4, 1) + input_shape = (2, 2, 4) + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32, shape=input_shape) + output = periodic_resample(x, desired_shape) + error = gradient_checker.compute_gradient_error( + x, input_shape, output, result_shape) + self.assertLess(error, 1e-4) + + def testPeriodicResampleShapeInference(self): + with self.test_session() as sess: + # Case 1: output shape can be fully inferreed. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4)) + output = periodic_resample(x, [4, 4, None]) + self.assertEqual(output.shape, [4, 4, 1]) + # Case 2: output shape can not be inferred - report desired shape. + x = array_ops.placeholder(dtypes.float32, shape=(2, 2, None)) + output = periodic_resample(x, [4, 4, None]) + self.assertTrue(output.shape.is_compatible_with([4, 4, None])) + self.assertEqual(output.shape[2].value, None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py index 348623d8f8d0c2ed60f559eca281343722038100..470e300ccbe7108fd49718341f4a522683366fe3 100644 --- a/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py +++ b/tensorflow/contrib/periodic_resample/python/ops/periodic_resample_op.py @@ -21,11 +21,17 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.periodic_resample.python.ops import gen_periodic_resample_op -from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample +from tensorflow.contrib.periodic_resample.python.ops.gen_periodic_resample_op import periodic_resample, periodic_resample_op_grad from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader # pylint: enable=unused-import _periodic_resample_op = loader.load_op_library( resource_loader.get_path_to_datafile('_periodic_resample_op.so')) + +@ops.RegisterGradient("PeriodicResample") +def _periodic_resample_grad_cc(op, grad): + return periodic_resample_op_grad( + grad, op.inputs[0].shape, op.get_attr('shape')) diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index b7a98c68e2343e9c8bb4b41556dc96bfe4ef444c..af3b2ad1b531b835f484a155efcc57bbe634f2df 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -34,7 +34,8 @@ class ContribEstimatorPredictor(predictor.Predictor): prediction_input_fn, input_alternative_key=None, output_alternative_key=None, - graph=None): + graph=None, + config=None): """Initialize a `ContribEstimatorPredictor`. Args: @@ -48,6 +49,7 @@ class ContribEstimatorPredictor(predictor.Predictor): multi-headed models. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. """ self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -58,6 +60,7 @@ class ContribEstimatorPredictor(predictor.Predictor): checkpoint_path = saver.latest_checkpoint(estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_filename_with_path=checkpoint_path)) input_alternative_key = ( diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py index d78d94c2699b14c80e7decee2181d190a6d91f99..a725072e72df2db64cde5ea31ab16e7c2dc5d2ce 100644 --- a/tensorflow/contrib/predictor/core_estimator_predictor.py +++ b/tensorflow/contrib/predictor/core_estimator_predictor.py @@ -51,7 +51,8 @@ class CoreEstimatorPredictor(predictor.Predictor): estimator, serving_input_receiver_fn, output_key=None, - graph=None): + graph=None, + config=None): """Initialize a `CoreEstimatorPredictor`. Args: @@ -62,6 +63,7 @@ class CoreEstimatorPredictor(predictor.Predictor): `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. """ self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -71,6 +73,7 @@ class CoreEstimatorPredictor(predictor.Predictor): checkpoint_dir = estimator.model_dir self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_dir=checkpoint_dir)) feed_tensor_info = signature_def.inputs diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index 6e77e934fe19851eea9ed0b74eb7aecc76f6237a..f275bc15adfa0a51a48964dff8edddbd45500e45 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -30,7 +30,8 @@ def from_contrib_estimator(estimator, prediction_input_fn, input_alternative_key=None, output_alternative_key=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `tf.contrib.learn.Estimator`. Args: @@ -44,6 +45,7 @@ def from_contrib_estimator(estimator, multi-headed models. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -62,13 +64,15 @@ def from_contrib_estimator(estimator, prediction_input_fn, input_alternative_key=input_alternative_key, output_alternative_key=output_alternative_key, - graph=graph) + graph=graph, + config=config) def from_estimator(estimator, serving_input_receiver_fn, output_key=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `tf.python.estimator.Estimator`. Args: @@ -79,6 +83,7 @@ def from_estimator(estimator, `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -93,14 +98,19 @@ def from_estimator(estimator, 'tf.contrib.learn.Estimator. You likely want to call ' 'from_contrib_estimator.') return core_estimator_predictor.CoreEstimatorPredictor( - estimator, serving_input_receiver_fn, output_key=output_key, graph=graph) + estimator, + serving_input_receiver_fn, + output_key=output_key, + graph=graph, + config=config) def from_saved_model(export_dir, signature_def_key=None, signature_def=None, tags=None, - graph=None): + graph=None, + config=None): """Constructs a `Predictor` from a `SavedModel` on disk. Args: @@ -115,6 +125,7 @@ def from_saved_model(export_dir, `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Returns: An initialized `Predictor`. @@ -128,4 +139,5 @@ def from_saved_model(export_dir, signature_def_key=signature_def_key, signature_def=signature_def, tags=tags, - graph=graph) + graph=graph, + config=config) diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py index 578d9424b25dd38f1d77a267d1fdf1ff9ff2da88..a2ef1dc3af0986afacf646f0dc04b7ef857a7f93 100644 --- a/tensorflow/contrib/predictor/predictor_factories_test.py +++ b/tensorflow/contrib/predictor/predictor_factories_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.predictor import predictor_factories from tensorflow.contrib.predictor import testing_common +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import test MODEL_DIR_NAME = 'contrib/predictor/test_export_dir' @@ -41,6 +42,11 @@ class PredictorFactoriesTest(test.TestCase): """Test loading from_saved_model with tags.""" predictor_factories.from_saved_model(self._export_dir, tags='serve') + def testFromSavedModelWithSessionConfig(self): + """Test loading from_saved_model with session config.""" + predictor_factories.from_saved_model( + self._export_dir, config=config_pb2.ConfigProto()) + def testFromSavedModelWithBadTags(self): """Test that loading fails for bad tags.""" bad_tags_regex = ('.*? could not be found in SavedModel') @@ -53,6 +59,13 @@ class PredictorFactoriesTest(test.TestCase): predictor_factories.from_contrib_estimator( estimator, input_fn, output_alternative_key='sum') + def testFromContribEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=False) + input_fn = testing_common.get_arithmetic_input_fn(core=False) + predictor_factories.from_contrib_estimator( + estimator, input_fn, output_alternative_key='sum', + config=config_pb2.ConfigProto()) + def testFromContribEstimatorWithCoreEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=True) input_fn = testing_common.get_arithmetic_input_fn(core=True) @@ -64,6 +77,12 @@ class PredictorFactoriesTest(test.TestCase): input_fn = testing_common.get_arithmetic_input_fn(core=True) predictor_factories.from_estimator(estimator, input_fn) + def testFromCoreEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=True) + input_fn = testing_common.get_arithmetic_input_fn(core=True) + predictor_factories.from_estimator( + estimator, input_fn, config=config_pb2.ConfigProto()) + def testFromCoreEstimatorWithContribEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=False) input_fn = testing_common.get_arithmetic_input_fn(core=False) diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py index 0dbca0f8136e4e618234101ee41c80bc085511c0..95da6d04edc5214d1b5c1851c4ab05c6d7080b9b 100644 --- a/tensorflow/contrib/predictor/saved_model_predictor.py +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -121,7 +121,8 @@ class SavedModelPredictor(predictor.Predictor): input_names=None, output_names=None, tags=None, - graph=None): + graph=None, + config=None): """Initialize a `CoreEstimatorPredictor`. Args: @@ -142,6 +143,7 @@ class SavedModelPredictor(predictor.Predictor): the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. Raises: ValueError: If more than one of signature_def_key OR signature_def OR (input_names AND output_names) is specified. @@ -152,7 +154,7 @@ class SavedModelPredictor(predictor.Predictor): self._graph = graph or ops.Graph() with self._graph.as_default(): - self._session = session.Session() + self._session = session.Session(config=config) loader.load(self._session, tags.split(','), export_dir) if input_names is None: diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index c83623ec947c1550991352a9dd9a5c6ee9282290..27a933c0f945e53a1838aefd30aed82fadbbc146 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -6,7 +6,7 @@ inference. The details of the transformation implemented in this package is described here [1]. This is done using the -[fake quantization op](https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization). +[fake quantization op](https://www.tensorflow.org/api_guides/python/array_ops#Fake_quantization). Literature has shown that fixed point networks provide comparable performance to floating point networks [2]. This is achieved by modeling the quantization diff --git a/tensorflow/contrib/receptive_field/README.md b/tensorflow/contrib/receptive_field/README.md index 3ff85faf611afad71b6e6203453bbe97c56f9242..79b015a9163f5727caa40b54579c71e57621c92f 100644 --- a/tensorflow/contrib/receptive_field/README.md +++ b/tensorflow/contrib/receptive_field/README.md @@ -6,6 +6,32 @@ region your output features depend on. Better yet, using the parameters computed by the library, you can easily find the exact image region which is used to compute each convnet feature. +This library can be used to compute receptive field parameters of popular +convnets: + +

+ +convnet model | receptive field | effective stride | effective padding +:-----------------: | :-------------: | :--------------: | :---------------: +alexnet_v2 | 195 | 32 | 64 +vgg_16 | 212 | 32 | 90 +inception_v2 | 699 | 32 | 318 +inception_v3 | 1311 | 32 | 618 +inception_v4 | 2071 | 32 | 998 +inception_resnet_v2 | 3039 | 32 | 1482 +mobilenet_v1 | 315 | 32 | 126 +mobilenet_v1_075 | 315 | 32 | 126 +resnet_v1_50 | 483 | 32 | 241 +resnet_v1_101 | 1027 | 32 | 513 +resnet_v1_152 | 1507 | 32 | 753 +resnet_v1_200 | 1763 | 32 | 881 + +
+ +A comprehensive table with pre-computed receptive field parameters for different +end-points, input resolutions, and other variants of these networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). + ## Basic usage The main function to be called is `compute_receptive_field_from_graph_def`, @@ -96,9 +122,9 @@ The script will write to stdout the receptive field parameters for many variants of several popular convnets: AlexNet, VGG, ResNet, Inception, Mobilenet. They are also written to the file `/tmp/rf_benchmark_results.csv`. -TODO: include here a plot for receptive field sizes of different convnets. - -TODO: include table/link to pre-computed RF parameters. +A comprehensive table with pre-computed receptive field parameters for different +networks can be found +[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md). ## Compute RF parameters from a graph pbtxt diff --git a/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md new file mode 100644 index 0000000000000000000000000000000000000000..736fbef6e7c66176e74144115f0b1acd6bf6cd2f --- /dev/null +++ b/tensorflow/contrib/receptive_field/RECEPTIVE_FIELD_TABLE.md @@ -0,0 +1,629 @@ +# Pre-computed receptive field parameters + +## Table with results + +The table below presents the receptive field parameters for several popular +convolutional neural networks. These are computed using the models from the +[TF-Slim +repository](https://github.com/tensorflow/models/tree/master/research/slim), +by using the [rf_benchmark +script](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/examples/rf_benchmark.py). + +Questions? See the [FAQ](#faq). + +CNN | resolution | end-point | RF | effective stride | effective padding +:----------------------------: | :--------: | :------------------: | :--: | :--------------: | :---------------: +alexnet_v2 | None | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | None | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | None | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | None | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | None | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | None | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | None | alexnet_v2/pool5 | 195 | 32 | 64 +alexnet_v2 | 224 | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | 224 | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | 224 | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | 224 | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | 224 | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | 224 | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | 224 | alexnet_v2/pool5 | 195 | 32 | 64 +alexnet_v2 | 321 | alexnet_v2/conv1 | 11 | 4 | 0 +alexnet_v2 | 321 | alexnet_v2/pool1 | 19 | 8 | 0 +alexnet_v2 | 321 | alexnet_v2/conv2 | 51 | 8 | 16 +alexnet_v2 | 321 | alexnet_v2/conv3 | 99 | 16 | 32 +alexnet_v2 | 321 | alexnet_v2/conv4 | 131 | 16 | 48 +alexnet_v2 | 321 | alexnet_v2/conv5 | 163 | 16 | 64 +alexnet_v2 | 321 | alexnet_v2/pool5 | 195 | 32 | 64 +vgg_a | None | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | None | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | None | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | None | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | None | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | None | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | None | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | None | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | None | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | None | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | None | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | None | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | None | vgg_a/pool5 | 150 | 32 | 59 +vgg_a | 224 | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | 224 | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | 224 | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | 224 | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | 224 | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | 224 | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | 224 | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | 224 | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | 224 | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | 224 | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | 224 | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | 224 | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | 224 | vgg_a/pool5 | 150 | 32 | 59 +vgg_a | 321 | vgg_a/conv1/conv1_1 | 3 | 1 | 1 +vgg_a | 321 | vgg_a/pool1 | 4 | 2 | 1 +vgg_a | 321 | vgg_a/conv2/conv2_1 | 8 | 2 | 3 +vgg_a | 321 | vgg_a/pool2 | 10 | 4 | 3 +vgg_a | 321 | vgg_a/conv3/conv3_1 | 18 | 4 | 7 +vgg_a | 321 | vgg_a/conv3/conv3_2 | 26 | 4 | 11 +vgg_a | 321 | vgg_a/pool3 | 30 | 8 | 11 +vgg_a | 321 | vgg_a/conv4/conv4_1 | 46 | 8 | 19 +vgg_a | 321 | vgg_a/conv4/conv4_2 | 62 | 8 | 27 +vgg_a | 321 | vgg_a/pool4 | 70 | 16 | 27 +vgg_a | 321 | vgg_a/conv5/conv5_1 | 102 | 16 | 43 +vgg_a | 321 | vgg_a/conv5/conv5_2 | 134 | 16 | 59 +vgg_a | 321 | vgg_a/pool5 | 150 | 32 | 59 +vgg_16 | None | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | None | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | None | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | None | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | None | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | None | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | None | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | None | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | None | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | None | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | None | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | None | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | None | vgg_16/pool5 | 212 | 32 | 90 +vgg_16 | 224 | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | 224 | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | 224 | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | 224 | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | 224 | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | 224 | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | 224 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 224 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 224 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 224 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 224 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 224 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 224 | vgg_16/pool5 | 212 | 32 | 90 +vgg_16 | 321 | vgg_16/conv1/conv1_1 | 3 | 1 | 1 +vgg_16 | 321 | vgg_16/pool1 | 6 | 2 | 2 +vgg_16 | 321 | vgg_16/conv2/conv2_1 | 10 | 2 | 4 +vgg_16 | 321 | vgg_16/pool2 | 16 | 4 | 6 +vgg_16 | 321 | vgg_16/conv3/conv3_1 | 24 | 4 | 10 +vgg_16 | 321 | vgg_16/conv3/conv3_2 | 32 | 4 | 14 +vgg_16 | 321 | vgg_16/pool3 | 44 | 8 | 18 +vgg_16 | 321 | vgg_16/conv4/conv4_1 | 60 | 8 | 26 +vgg_16 | 321 | vgg_16/conv4/conv4_2 | 76 | 8 | 34 +vgg_16 | 321 | vgg_16/pool4 | 100 | 16 | 42 +vgg_16 | 321 | vgg_16/conv5/conv5_1 | 132 | 16 | 58 +vgg_16 | 321 | vgg_16/conv5/conv5_2 | 164 | 16 | 74 +vgg_16 | 321 | vgg_16/pool5 | 212 | 32 | 90 +inception_v2 | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2 | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2 | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2 | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2 | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2 | None | Mixed_3b | 59 | 8 | None +inception_v2 | None | Mixed_3c | 91 | 8 | None +inception_v2 | None | Mixed_4a | 123 | 16 | None +inception_v2 | None | Mixed_4b | 187 | 16 | None +inception_v2 | None | Mixed_4c | 251 | 16 | None +inception_v2 | None | Mixed_4d | 315 | 16 | None +inception_v2 | None | Mixed_4e | 379 | 16 | None +inception_v2 | None | Mixed_5a | 443 | 32 | None +inception_v2 | None | Mixed_5b | 571 | 32 | None +inception_v2 | None | Mixed_5c | 699 | 32 | None +inception_v2 | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2 | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2 | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2 | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 +inception_v2 | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2 | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2 | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2 | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2 | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2 | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2 | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2 | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2 | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2 | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2 | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2 | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2 | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2 | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2 | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2 | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2 | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2 | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2 | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2 | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2 | 321 | Mixed_4e | 379 | 16 | 189 +inception_v2 | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2 | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2 | 321 | Mixed_5c | 699 | 32 | 349 +inception_v2-no-separable-conv | None | Conv2d_1a_7x7 | 7 | 2 | None +inception_v2-no-separable-conv | None | MaxPool_2a_3x3 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2b_1x1 | 11 | 4 | None +inception_v2-no-separable-conv | None | Conv2d_2c_3x3 | 19 | 4 | None +inception_v2-no-separable-conv | None | MaxPool_3a_3x3 | 27 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3b | 59 | 8 | None +inception_v2-no-separable-conv | None | Mixed_3c | 91 | 8 | None +inception_v2-no-separable-conv | None | Mixed_4a | 123 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4b | 187 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4c | 251 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4d | 315 | 16 | None +inception_v2-no-separable-conv | None | Mixed_4e | 379 | 16 | None +inception_v2-no-separable-conv | None | Mixed_5a | 443 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5b | 571 | 32 | None +inception_v2-no-separable-conv | None | Mixed_5c | 699 | 32 | None +inception_v2-no-separable-conv | 224 | Conv2d_1a_7x7 | 7 | 2 | 2 +inception_v2-no-separable-conv | 224 | MaxPool_2a_3x3 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2b_1x1 | 11 | 4 | 2 +inception_v2-no-separable-conv | 224 | Conv2d_2c_3x3 | 19 | 4 | 6 +inception_v2-no-separable-conv | 224 | MaxPool_3a_3x3 | 27 | 8 | 6 +inception_v2-no-separable-conv | 224 | Mixed_3b | 59 | 8 | 22 +inception_v2-no-separable-conv | 224 | Mixed_3c | 91 | 8 | 38 +inception_v2-no-separable-conv | 224 | Mixed_4a | 123 | 16 | 46 +inception_v2-no-separable-conv | 224 | Mixed_4b | 187 | 16 | 78 +inception_v2-no-separable-conv | 224 | Mixed_4c | 251 | 16 | 110 +inception_v2-no-separable-conv | 224 | Mixed_4d | 315 | 16 | 142 +inception_v2-no-separable-conv | 224 | Mixed_4e | 379 | 16 | 174 +inception_v2-no-separable-conv | 224 | Mixed_5a | 443 | 32 | 190 +inception_v2-no-separable-conv | 224 | Mixed_5b | 571 | 32 | 254 +inception_v2-no-separable-conv | 224 | Mixed_5c | 699 | 32 | 318 +inception_v2-no-separable-conv | 321 | Conv2d_1a_7x7 | 7 | 2 | 3 +inception_v2-no-separable-conv | 321 | MaxPool_2a_3x3 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2b_1x1 | 11 | 4 | 5 +inception_v2-no-separable-conv | 321 | Conv2d_2c_3x3 | 19 | 4 | 9 +inception_v2-no-separable-conv | 321 | MaxPool_3a_3x3 | 27 | 8 | 13 +inception_v2-no-separable-conv | 321 | Mixed_3b | 59 | 8 | 29 +inception_v2-no-separable-conv | 321 | Mixed_3c | 91 | 8 | 45 +inception_v2-no-separable-conv | 321 | Mixed_4a | 123 | 16 | 61 +inception_v2-no-separable-conv | 321 | Mixed_4b | 187 | 16 | 93 +inception_v2-no-separable-conv | 321 | Mixed_4c | 251 | 16 | 125 +inception_v2-no-separable-conv | 321 | Mixed_4d | 315 | 16 | 157 +inception_v2-no-separable-conv | 321 | Mixed_4e | 379 | 16 | 189 +inception_v2-no-separable-conv | 321 | Mixed_5a | 443 | 32 | 221 +inception_v2-no-separable-conv | 321 | Mixed_5b | 571 | 32 | 285 +inception_v2-no-separable-conv | 321 | Mixed_5c | 699 | 32 | 349 +inception_v3 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | None | Mixed_5b | 63 | 8 | 18 +inception_v3 | None | Mixed_5c | 95 | 8 | 34 +inception_v3 | None | Mixed_5d | 127 | 8 | 50 +inception_v3 | None | Mixed_6a | 159 | 16 | 58 +inception_v3 | None | Mixed_6b | 351 | 16 | 154 +inception_v3 | None | Mixed_6c | 543 | 16 | 250 +inception_v3 | None | Mixed_6d | 735 | 16 | 346 +inception_v3 | None | Mixed_6e | 927 | 16 | 442 +inception_v3 | None | Mixed_7a | 1055 | 32 | 490 +inception_v3 | None | Mixed_7b | 1183 | 32 | 554 +inception_v3 | None | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 224 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 224 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 224 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 224 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 224 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 224 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 224 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 224 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 224 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 224 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 224 | Mixed_7c | 1311 | 32 | 618 +inception_v3 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v3 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v3 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v3 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_v3 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_v3 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_v3 | 321 | Mixed_5b | 63 | 8 | 18 +inception_v3 | 321 | Mixed_5c | 95 | 8 | 34 +inception_v3 | 321 | Mixed_5d | 127 | 8 | 50 +inception_v3 | 321 | Mixed_6a | 159 | 16 | 58 +inception_v3 | 321 | Mixed_6b | 351 | 16 | 154 +inception_v3 | 321 | Mixed_6c | 543 | 16 | 250 +inception_v3 | 321 | Mixed_6d | 735 | 16 | 346 +inception_v3 | 321 | Mixed_6e | 927 | 16 | 442 +inception_v3 | 321 | Mixed_7a | 1055 | 32 | 490 +inception_v3 | 321 | Mixed_7b | 1183 | 32 | 554 +inception_v3 | 321 | Mixed_7c | 1311 | 32 | 618 +inception_v4 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | None | Mixed_3a | 15 | 4 | 2 +inception_v4 | None | Mixed_4a | 47 | 4 | 14 +inception_v4 | None | Mixed_5a | 55 | 8 | 14 +inception_v4 | None | Mixed_5b | 87 | 8 | 30 +inception_v4 | None | Mixed_5c | 119 | 8 | 46 +inception_v4 | None | Mixed_5d | 151 | 8 | 62 +inception_v4 | None | Mixed_5e | 183 | 8 | 78 +inception_v4 | None | Mixed_6a | 215 | 16 | 86 +inception_v4 | None | Mixed_6b | 407 | 16 | 182 +inception_v4 | None | Mixed_6c | 599 | 16 | 278 +inception_v4 | None | Mixed_6d | 791 | 16 | 374 +inception_v4 | None | Mixed_6e | 983 | 16 | 470 +inception_v4 | None | Mixed_6f | 1175 | 16 | 566 +inception_v4 | None | Mixed_6g | 1367 | 16 | 662 +inception_v4 | None | Mixed_6h | 1559 | 16 | 758 +inception_v4 | None | Mixed_7a | 1687 | 32 | 806 +inception_v4 | None | Mixed_7b | 1815 | 32 | 870 +inception_v4 | None | Mixed_7c | 1943 | 32 | 934 +inception_v4 | None | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 224 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 224 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 224 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 224 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 224 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 224 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 224 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 224 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 224 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 224 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 224 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 224 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 224 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 224 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 224 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 224 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 224 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 224 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 224 | Mixed_7d | 2071 | 32 | 998 +inception_v4 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_v4 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_v4 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_v4 | 321 | Mixed_3a | 15 | 4 | 2 +inception_v4 | 321 | Mixed_4a | 47 | 4 | 14 +inception_v4 | 321 | Mixed_5a | 55 | 8 | 14 +inception_v4 | 321 | Mixed_5b | 87 | 8 | 30 +inception_v4 | 321 | Mixed_5c | 119 | 8 | 46 +inception_v4 | 321 | Mixed_5d | 151 | 8 | 62 +inception_v4 | 321 | Mixed_5e | 183 | 8 | 78 +inception_v4 | 321 | Mixed_6a | 215 | 16 | 86 +inception_v4 | 321 | Mixed_6b | 407 | 16 | 182 +inception_v4 | 321 | Mixed_6c | 599 | 16 | 278 +inception_v4 | 321 | Mixed_6d | 791 | 16 | 374 +inception_v4 | 321 | Mixed_6e | 983 | 16 | 470 +inception_v4 | 321 | Mixed_6f | 1175 | 16 | 566 +inception_v4 | 321 | Mixed_6g | 1367 | 16 | 662 +inception_v4 | 321 | Mixed_6h | 1559 | 16 | 758 +inception_v4 | 321 | Mixed_7a | 1687 | 32 | 806 +inception_v4 | 321 | Mixed_7b | 1815 | 32 | 870 +inception_v4 | 321 | Mixed_7c | 1943 | 32 | 934 +inception_v4 | 321 | Mixed_7d | 2071 | 32 | 998 +inception_resnet_v2 | None | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | None | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | None | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | None | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | None | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | None | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | None | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | None | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | None | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | None | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 224 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 224 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 224 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 224 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 224 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 224 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 224 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 224 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2 | 321 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2a_3x3 | 7 | 2 | 0 +inception_resnet_v2 | 321 | Conv2d_2b_3x3 | 11 | 2 | 2 +inception_resnet_v2 | 321 | MaxPool_3a_3x3 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_3b_1x1 | 15 | 4 | 2 +inception_resnet_v2 | 321 | Conv2d_4a_3x3 | 23 | 4 | 2 +inception_resnet_v2 | 321 | MaxPool_5a_3x3 | 31 | 8 | 2 +inception_resnet_v2 | 321 | Mixed_5b | 63 | 8 | 18 +inception_resnet_v2 | 321 | Mixed_6a | 415 | 16 | 186 +inception_resnet_v2 | 321 | PreAuxLogits | 2335 | 16 | 1146 +inception_resnet_v2 | 321 | Mixed_7a | 2399 | 32 | 1162 +inception_resnet_v2 | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1482 +inception_resnet_v2-same | None | Conv2d_1a_3x3 | 3 | 2 | None +inception_resnet_v2-same | None | Conv2d_2a_3x3 | 7 | 2 | None +inception_resnet_v2-same | None | Conv2d_2b_3x3 | 11 | 2 | None +inception_resnet_v2-same | None | MaxPool_3a_3x3 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_3b_1x1 | 15 | 4 | None +inception_resnet_v2-same | None | Conv2d_4a_3x3 | 23 | 4 | None +inception_resnet_v2-same | None | MaxPool_5a_3x3 | 31 | 8 | None +inception_resnet_v2-same | None | Mixed_5b | 63 | 8 | None +inception_resnet_v2-same | None | Mixed_6a | 415 | 16 | None +inception_resnet_v2-same | None | PreAuxLogits | 2335 | 16 | None +inception_resnet_v2-same | None | Mixed_7a | 2399 | 32 | None +inception_resnet_v2-same | None | Conv2d_7b_1x1 | 3039 | 32 | None +inception_resnet_v2-same | 224 | Conv2d_1a_3x3 | 3 | 2 | 0 +inception_resnet_v2-same | 224 | Conv2d_2a_3x3 | 7 | 2 | 2 +inception_resnet_v2-same | 224 | Conv2d_2b_3x3 | 11 | 2 | 4 +inception_resnet_v2-same | 224 | MaxPool_3a_3x3 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_3b_1x1 | 15 | 4 | 4 +inception_resnet_v2-same | 224 | Conv2d_4a_3x3 | 23 | 4 | 8 +inception_resnet_v2-same | 224 | MaxPool_5a_3x3 | 31 | 8 | 8 +inception_resnet_v2-same | 224 | Mixed_5b | 63 | 8 | 24 +inception_resnet_v2-same | 224 | Mixed_6a | 415 | 16 | 192 +inception_resnet_v2-same | 224 | PreAuxLogits | 2335 | 16 | 1152 +inception_resnet_v2-same | 224 | Mixed_7a | 2399 | 32 | 1168 +inception_resnet_v2-same | 224 | Conv2d_7b_1x1 | 3039 | 32 | 1488 +inception_resnet_v2-same | 321 | Conv2d_1a_3x3 | 3 | 2 | 1 +inception_resnet_v2-same | 321 | Conv2d_2a_3x3 | 7 | 2 | 3 +inception_resnet_v2-same | 321 | Conv2d_2b_3x3 | 11 | 2 | 5 +inception_resnet_v2-same | 321 | MaxPool_3a_3x3 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_3b_1x1 | 15 | 4 | 7 +inception_resnet_v2-same | 321 | Conv2d_4a_3x3 | 23 | 4 | 11 +inception_resnet_v2-same | 321 | MaxPool_5a_3x3 | 31 | 8 | 15 +inception_resnet_v2-same | 321 | Mixed_5b | 63 | 8 | 31 +inception_resnet_v2-same | 321 | Mixed_6a | 415 | 16 | 207 +inception_resnet_v2-same | 321 | PreAuxLogits | 2335 | 16 | 1167 +inception_resnet_v2-same | 321 | Mixed_7a | 2399 | 32 | 1199 +inception_resnet_v2-same | 321 | Conv2d_7b_1x1 | 3039 | 32 | 1519 +mobilenet_v1 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1 | 224 | Conv2d_7_pointwise | 91 | 16 | 30 +mobilenet_v1 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 +mobilenet_v1 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +mobilenet_v1_075 | None | Conv2d_0 | 3 | 2 | None +mobilenet_v1_075 | None | Conv2d_1_pointwise | 7 | 2 | None +mobilenet_v1_075 | None | Conv2d_2_pointwise | 11 | 4 | None +mobilenet_v1_075 | None | Conv2d_3_pointwise | 19 | 4 | None +mobilenet_v1_075 | None | Conv2d_4_pointwise | 27 | 8 | None +mobilenet_v1_075 | None | Conv2d_5_pointwise | 43 | 8 | None +mobilenet_v1_075 | None | Conv2d_6_pointwise | 59 | 16 | None +mobilenet_v1_075 | None | Conv2d_7_pointwise | 91 | 16 | None +mobilenet_v1_075 | None | Conv2d_8_pointwise | 123 | 16 | None +mobilenet_v1_075 | None | Conv2d_9_pointwise | 155 | 16 | None +mobilenet_v1_075 | None | Conv2d_10_pointwise | 187 | 16 | None +mobilenet_v1_075 | None | Conv2d_11_pointwise | 219 | 16 | None +mobilenet_v1_075 | None | Conv2d_12_pointwise | 251 | 32 | None +mobilenet_v1_075 | None | Conv2d_13_pointwise | 315 | 32 | None +mobilenet_v1_075 | 224 | Conv2d_0 | 3 | 2 | 0 +mobilenet_v1_075 | 224 | Conv2d_1_pointwise | 7 | 2 | 2 +mobilenet_v1_075 | 224 | Conv2d_2_pointwise | 11 | 4 | 2 +mobilenet_v1_075 | 224 | Conv2d_3_pointwise | 19 | 4 | 6 +mobilenet_v1_075 | 224 | Conv2d_4_pointwise | 27 | 8 | 6 +mobilenet_v1_075 | 224 | Conv2d_5_pointwise | 43 | 8 | 14 +mobilenet_v1_075 | 224 | Conv2d_6_pointwise | 59 | 16 | 14 +mobilenet_v1_075 | 224 | Conv2d_7_pointwise | 91 | 16 | 30 +mobilenet_v1_075 | 224 | Conv2d_8_pointwise | 123 | 16 | 46 +mobilenet_v1_075 | 224 | Conv2d_9_pointwise | 155 | 16 | 62 +mobilenet_v1_075 | 224 | Conv2d_10_pointwise | 187 | 16 | 78 +mobilenet_v1_075 | 224 | Conv2d_11_pointwise | 219 | 16 | 94 +mobilenet_v1_075 | 224 | Conv2d_12_pointwise | 251 | 32 | 94 +mobilenet_v1_075 | 224 | Conv2d_13_pointwise | 315 | 32 | 126 +mobilenet_v1_075 | 321 | Conv2d_0 | 3 | 2 | 1 +mobilenet_v1_075 | 321 | Conv2d_1_pointwise | 7 | 2 | 3 +mobilenet_v1_075 | 321 | Conv2d_2_pointwise | 11 | 4 | 5 +mobilenet_v1_075 | 321 | Conv2d_3_pointwise | 19 | 4 | 9 +mobilenet_v1_075 | 321 | Conv2d_4_pointwise | 27 | 8 | 13 +mobilenet_v1_075 | 321 | Conv2d_5_pointwise | 43 | 8 | 21 +mobilenet_v1_075 | 321 | Conv2d_6_pointwise | 59 | 16 | 29 +mobilenet_v1_075 | 321 | Conv2d_7_pointwise | 91 | 16 | 45 +mobilenet_v1_075 | 321 | Conv2d_8_pointwise | 123 | 16 | 61 +mobilenet_v1_075 | 321 | Conv2d_9_pointwise | 155 | 16 | 77 +mobilenet_v1_075 | 321 | Conv2d_10_pointwise | 187 | 16 | 93 +mobilenet_v1_075 | 321 | Conv2d_11_pointwise | 219 | 16 | 109 +mobilenet_v1_075 | 321 | Conv2d_12_pointwise | 251 | 32 | 125 +mobilenet_v1_075 | 321 | Conv2d_13_pointwise | 315 | 32 | 157 +resnet_v1_50 | None | resnet_v1_50/block1 | 35 | 8 | None +resnet_v1_50 | None | resnet_v1_50/block2 | 99 | 16 | None +resnet_v1_50 | None | resnet_v1_50/block3 | 291 | 32 | None +resnet_v1_50 | None | resnet_v1_50/block4 | 483 | 32 | None +resnet_v1_50 | 224 | resnet_v1_50/block1 | 35 | 8 | 15 +resnet_v1_50 | 224 | resnet_v1_50/block2 | 99 | 16 | 47 +resnet_v1_50 | 224 | resnet_v1_50/block3 | 291 | 32 | 143 +resnet_v1_50 | 224 | resnet_v1_50/block4 | 483 | 32 | 239 +resnet_v1_50 | 321 | resnet_v1_50/block1 | 35 | 8 | 17 +resnet_v1_50 | 321 | resnet_v1_50/block2 | 99 | 16 | 49 +resnet_v1_50 | 321 | resnet_v1_50/block3 | 291 | 32 | 145 +resnet_v1_50 | 321 | resnet_v1_50/block4 | 483 | 32 | 241 +resnet_v1_101 | None | resnet_v1_101/block1 | 35 | 8 | None +resnet_v1_101 | None | resnet_v1_101/block2 | 99 | 16 | None +resnet_v1_101 | None | resnet_v1_101/block3 | 835 | 32 | None +resnet_v1_101 | None | resnet_v1_101/block4 | 1027 | 32 | None +resnet_v1_101 | 224 | resnet_v1_101/block1 | 35 | 8 | 15 +resnet_v1_101 | 224 | resnet_v1_101/block2 | 99 | 16 | 47 +resnet_v1_101 | 224 | resnet_v1_101/block3 | 835 | 32 | 415 +resnet_v1_101 | 224 | resnet_v1_101/block4 | 1027 | 32 | 511 +resnet_v1_101 | 321 | resnet_v1_101/block1 | 35 | 8 | 17 +resnet_v1_101 | 321 | resnet_v1_101/block2 | 99 | 16 | 49 +resnet_v1_101 | 321 | resnet_v1_101/block3 | 835 | 32 | 417 +resnet_v1_101 | 321 | resnet_v1_101/block4 | 1027 | 32 | 513 +resnet_v1_152 | None | resnet_v1_152/block1 | 35 | 8 | None +resnet_v1_152 | None | resnet_v1_152/block2 | 163 | 16 | None +resnet_v1_152 | None | resnet_v1_152/block3 | 1315 | 32 | None +resnet_v1_152 | None | resnet_v1_152/block4 | 1507 | 32 | None +resnet_v1_152 | 224 | resnet_v1_152/block1 | 35 | 8 | 15 +resnet_v1_152 | 224 | resnet_v1_152/block2 | 163 | 16 | 79 +resnet_v1_152 | 224 | resnet_v1_152/block3 | 1315 | 32 | 655 +resnet_v1_152 | 224 | resnet_v1_152/block4 | 1507 | 32 | 751 +resnet_v1_152 | 321 | resnet_v1_152/block1 | 35 | 8 | 17 +resnet_v1_152 | 321 | resnet_v1_152/block2 | 163 | 16 | 81 +resnet_v1_152 | 321 | resnet_v1_152/block3 | 1315 | 32 | 657 +resnet_v1_152 | 321 | resnet_v1_152/block4 | 1507 | 32 | 753 +resnet_v1_200 | None | resnet_v1_200/block1 | 35 | 8 | None +resnet_v1_200 | None | resnet_v1_200/block2 | 419 | 16 | None +resnet_v1_200 | None | resnet_v1_200/block3 | 1571 | 32 | None +resnet_v1_200 | None | resnet_v1_200/block4 | 1763 | 32 | None +resnet_v1_200 | 224 | resnet_v1_200/block1 | 35 | 8 | 15 +resnet_v1_200 | 224 | resnet_v1_200/block2 | 419 | 16 | 207 +resnet_v1_200 | 224 | resnet_v1_200/block3 | 1571 | 32 | 783 +resnet_v1_200 | 224 | resnet_v1_200/block4 | 1763 | 32 | 879 +resnet_v1_200 | 321 | resnet_v1_200/block1 | 35 | 8 | 17 +resnet_v1_200 | 321 | resnet_v1_200/block2 | 419 | 16 | 209 +resnet_v1_200 | 321 | resnet_v1_200/block3 | 1571 | 32 | 785 +resnet_v1_200 | 321 | resnet_v1_200/block4 | 1763 | 32 | 881 +resnet_v2_50 | None | resnet_v2_50/block1 | 35 | 8 | None +resnet_v2_50 | None | resnet_v2_50/block2 | 99 | 16 | None +resnet_v2_50 | None | resnet_v2_50/block3 | 291 | 32 | None +resnet_v2_50 | None | resnet_v2_50/block4 | 483 | 32 | None +resnet_v2_50 | 224 | resnet_v2_50/block1 | 35 | 8 | 15 +resnet_v2_50 | 224 | resnet_v2_50/block2 | 99 | 16 | 47 +resnet_v2_50 | 224 | resnet_v2_50/block3 | 291 | 32 | 143 +resnet_v2_50 | 224 | resnet_v2_50/block4 | 483 | 32 | 239 +resnet_v2_50 | 321 | resnet_v2_50/block1 | 35 | 8 | 17 +resnet_v2_50 | 321 | resnet_v2_50/block2 | 99 | 16 | 49 +resnet_v2_50 | 321 | resnet_v2_50/block3 | 291 | 32 | 145 +resnet_v2_50 | 321 | resnet_v2_50/block4 | 483 | 32 | 241 +resnet_v2_101 | None | resnet_v2_101/block1 | 35 | 8 | None +resnet_v2_101 | None | resnet_v2_101/block2 | 99 | 16 | None +resnet_v2_101 | None | resnet_v2_101/block3 | 835 | 32 | None +resnet_v2_101 | None | resnet_v2_101/block4 | 1027 | 32 | None +resnet_v2_101 | 224 | resnet_v2_101/block1 | 35 | 8 | 15 +resnet_v2_101 | 224 | resnet_v2_101/block2 | 99 | 16 | 47 +resnet_v2_101 | 224 | resnet_v2_101/block3 | 835 | 32 | 415 +resnet_v2_101 | 224 | resnet_v2_101/block4 | 1027 | 32 | 511 +resnet_v2_101 | 321 | resnet_v2_101/block1 | 35 | 8 | 17 +resnet_v2_101 | 321 | resnet_v2_101/block2 | 99 | 16 | 49 +resnet_v2_101 | 321 | resnet_v2_101/block3 | 835 | 32 | 417 +resnet_v2_101 | 321 | resnet_v2_101/block4 | 1027 | 32 | 513 +resnet_v2_152 | None | resnet_v2_152/block1 | 35 | 8 | None +resnet_v2_152 | None | resnet_v2_152/block2 | 163 | 16 | None +resnet_v2_152 | None | resnet_v2_152/block3 | 1315 | 32 | None +resnet_v2_152 | None | resnet_v2_152/block4 | 1507 | 32 | None +resnet_v2_152 | 224 | resnet_v2_152/block1 | 35 | 8 | 15 +resnet_v2_152 | 224 | resnet_v2_152/block2 | 163 | 16 | 79 +resnet_v2_152 | 224 | resnet_v2_152/block3 | 1315 | 32 | 655 +resnet_v2_152 | 224 | resnet_v2_152/block4 | 1507 | 32 | 751 +resnet_v2_152 | 321 | resnet_v2_152/block1 | 35 | 8 | 17 +resnet_v2_152 | 321 | resnet_v2_152/block2 | 163 | 16 | 81 +resnet_v2_152 | 321 | resnet_v2_152/block3 | 1315 | 32 | 657 +resnet_v2_152 | 321 | resnet_v2_152/block4 | 1507 | 32 | 753 +resnet_v2_200 | None | resnet_v2_200/block1 | 35 | 8 | None +resnet_v2_200 | None | resnet_v2_200/block2 | 419 | 16 | None +resnet_v2_200 | None | resnet_v2_200/block3 | 1571 | 32 | None +resnet_v2_200 | None | resnet_v2_200/block4 | 1763 | 32 | None +resnet_v2_200 | 224 | resnet_v2_200/block1 | 35 | 8 | 15 +resnet_v2_200 | 224 | resnet_v2_200/block2 | 419 | 16 | 207 +resnet_v2_200 | 224 | resnet_v2_200/block3 | 1571 | 32 | 783 +resnet_v2_200 | 224 | resnet_v2_200/block4 | 1763 | 32 | 879 +resnet_v2_200 | 321 | resnet_v2_200/block1 | 35 | 8 | 17 +resnet_v2_200 | 321 | resnet_v2_200/block2 | 419 | 16 | 209 +resnet_v2_200 | 321 | resnet_v2_200/block3 | 1571 | 32 | 785 +resnet_v2_200 | 321 | resnet_v2_200/block4 | 1763 | 32 | 881 + +## FAQ + +### What does a resolution of 'None' mean? + +In this case, the input resolution is undefined. For most models, the receptive +field parameters can be computed even without knowing the input resolution. + +### For some networks, effective_padding shows as 'None' (eg, for Inception_v2 or Mobilenet_v1 when input size is not specified). Why is that? + +This means that the padding for these networks depends on the input size. So, +unless we know exactly the input image dimensionality to be used, it is not +possible to determine the padding applied at the different layers. Look at the +other entries where the input size is fixed; for those cases, effective_padding +is not None. + +This happens due to Tensorflow's implementation of the 'SAME' padding mode, +which may depend on the input feature map size to a given layer. For background +on this, see [these notes from the TF +documentation](https://www.tensorflow.org/versions/master/api_guides/python/nn#Notes_on_SAME_Convolution_Padding). + +Also, note that in this case the program is not able to check if the network is +aligned (ie, it could be that the different paths from input to output have +receptive fields which are not consistently centered at the same position in the +input image). + +So you should be aware that such networks might not be aligned -- the program +has no way of checking it when the padding cannot be determined. + +### The receptive field parameters for network X seem different from what I expected... maybe your calculation is incorrect? + +First, note that the results presented here are based on the tensorflow +implementations from the [TF-Slim model +library](https://github.com/tensorflow/models/tree/master/research/slim). + +So, it is possible that due to some implementation details the RF parameters are +different. + +One common case of confusion is the TF-Slim Resnet implementation, which applies +stride in the last residual unit of each block, instead of at the input +activations in the first residual unit of each block (which is what is described +in the Resnet paper) -- see [this +comment](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_utils.py#L30). +This makes the stride with respect to each convolution block potentially +different. In this case, though, note that a +[flag](https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v1.py#L150) +may be used to recover the original striding convention. + +Second, it could be that we have a bug somewhere. While we include [many +tests](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py) +in our library, it is always possible that we missed something. If you suspect +this is happening, please file a GitHub issue +[here](https://github.com/tensorflow/tensorflow/issues). diff --git a/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py new file mode 100644 index 0000000000000000000000000000000000000000..4495d74bbf66fa461a05f38b430dd404d7da4b08 --- /dev/null +++ b/tensorflow/contrib/receptive_field/python/util/examples/csv_to_markdown_table.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Simple script to convert CSV output from rf_benchmark to Markdown format. + +The input CSV should have the following fields: +- CNN +- input resolution +- end_point +- RF size hor +- RF size ver +- effective stride hor +- effective stride ver +- effective padding hor +- effective padding ver + +Since usually in all cases the parameters in the horizontal and vertical +directions are the same, this is assumed by this script, which only prints one +of them to the Markdown file. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import csv +import sys + +from tensorflow.python.platform import app + +cmd_args = None + + +def main(unused_argv): + with open(cmd_args.markdown_path, 'w') as f: + # Write table header and field size. + f.write('CNN | resolution | end-point | RF | effective stride | ' + 'effective padding|\n') + f.write( + ':--------------------: | :----------: | :---------------: | :-----: |' + ' :----: | :----:|\n') + with open(cmd_args.csv_path) as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + # Make sure horizontal and parameters are the same. + assert row['RF size hor'] == row['RF size ver'] + assert row['effective stride hor'] == row['effective stride ver'] + assert row['effective padding hor'] == row['effective padding ver'] + + f.write('%s|%s|%s|%s|%s|%s\n' % + (row['CNN'], row['input resolution'], row['end_point'], + row['RF size hor'], row['effective stride hor'], + row['effective padding hor'])) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--csv_path', + type=str, + default='/tmp/rf.csv', + help='Path where CSV output of rf_benchmark was saved.') + parser.add_argument( + '--markdown_path', + type=str, + default='/tmp/rf.md', + help='Path where Markdown output will be saved.') + cmd_args, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py index bc383a803496380aaba4d0248d2b7f93253b2b50..0e3c46f17d2e2a277418d39e31927db73a509670 100644 --- a/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py +++ b/tensorflow/contrib/receptive_field/python/util/parse_layer_parameters.py @@ -27,7 +27,7 @@ from tensorflow.python.platform import tf_logging as logging _UNCHANGED_RF_LAYER_OPS = [ "Add", "BiasAdd", "Cast", "Ceil", "ConcatV2", "Const", "Floor", "FusedBatchNorm", "Identity", "Log", "Mul", "Pow", "RealDiv", "Relu", - "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2" + "Relu6", "Round", "Rsqrt", "Softplus", "Sub", "VariableV2", "LRN" ] # Different ways in which padding modes may be spelled. diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 43c0f7595590802aa80e1012967d377a6ab83d29..4eb5c920b3517a8968ff730003e786ae2a9c9e26 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -193,6 +193,10 @@ tf_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = [ + "manual", + "notap", + ], ) cuda_py_tests( diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index fdecceff526a860a274354e53e824b98d11418a6..6bd58c4d322c04d4d14d04678e24a05c0f876208 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -1,4 +1,4 @@ -package(default_visibility = ["//tensorflow:__subpackages__"]) +package(default_visibility = ["//tensorflow:internal"]) licenses(["notice"]) # Apache 2.0 diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 94fc12ca814721acf62f16b72ffa50473043cc8b..3d0308aaf3da3b5b16fd22a2905db36917e8c97b 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -26,7 +26,6 @@ import time import numpy as np from tensorflow.contrib.framework.python.ops import variables as variables_lib -from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.contrib.slim.python.slim import evaluation from tensorflow.contrib.training.python.training import evaluation as evaluation_lib from tensorflow.core.protobuf import saver_pb2 @@ -37,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics from tensorflow.python.ops import variables from tensorflow.python.platform import flags from tensorflow.python.platform import gfile @@ -89,8 +89,8 @@ class EvaluationTest(test.TestCase): self._predictions, self._scale = TestModel(self._inputs) def testFinalOpsOnEvaluationLoop(self): - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) # Create checkpoint and log directories: @@ -136,9 +136,10 @@ class EvaluationTest(test.TestCase): self.assertTrue(obj.hook_was_run) def _create_names_to_metrics(self, predictions, labels): - accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels) - accuracy1, update_op1 = metric_ops.streaming_accuracy(predictions + 1, - labels) + accuracy0, update_op0 = metrics.accuracy( + labels=labels, predictions=predictions) + accuracy1, update_op1 = metrics.accuracy( + labels=labels, predictions=predictions + 1) names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1} names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1} @@ -198,8 +199,8 @@ class EvaluationTest(test.TestCase): predictions_limited = input.limit_epochs(self._predictions, num_epochs=1) labels_limited = input.limit_epochs(self._labels, num_epochs=1) - value_op, update_op = metric_ops.streaming_accuracy( - predictions_limited, labels_limited) + value_op, update_op = metrics.accuracy( + labels=labels_limited, predictions=predictions_limited) init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) @@ -260,8 +261,8 @@ class SingleEvaluationTest(test.TestCase): self._prepareCheckpoint(checkpoint_path) # Next, determine the metric to evaluate: - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) # Run the evaluation and verify the results: accuracy_value = evaluation.evaluate_once( @@ -276,8 +277,8 @@ class SingleEvaluationTest(test.TestCase): self._prepareCheckpoint(checkpoint_path) # Next, determine the metric to evaluate: - value_op, update_op = metric_ops.streaming_accuracy(self._predictions, - self._labels) + value_op, update_op = metrics.accuracy( + labels=self._labels, predictions=self._predictions) dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir') dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False) diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py index 9305c6a11c4ec898c82553773e8e7277a54ab82e..85918bf8506623cf5e0c9106ae9ed80e233f5a7d 100644 --- a/tensorflow/contrib/solvers/python/ops/linear_equations.py +++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py @@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import linalg_ops def conjugate_gradient(operator, diff --git a/tensorflow/contrib/stat_summarizer/BUILD b/tensorflow/contrib/stat_summarizer/BUILD index 30be14c10cd8576ded75b8489cc89d439a9cc282..0b8fc0cdc66ae41807cce92776ada263675b1f94 100644 --- a/tensorflow/contrib/stat_summarizer/BUILD +++ b/tensorflow/contrib/stat_summarizer/BUILD @@ -31,5 +31,8 @@ tf_py_test( "//tensorflow/python:math_ops", "//tensorflow/python:variables", ], - tags = ["no_windows"], + tags = [ + "no_windows", + "notap", # TODO(b/80546574): test is flaky + ], ) diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index e893e1d1c836cc7feef15757dde79d0db362cbaf..d8236a0a6fa6d0d0e383e454eb0146bb10b6f49d 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -21,10 +21,10 @@ import numpy as np from tensorflow.contrib import losses from tensorflow.contrib.learn.python.learn.estimators import prediction_key -from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics from tensorflow.python.ops import nn INFERENCE_PROB_NAME = prediction_key.PredictionKey.PROBABILITIES @@ -38,12 +38,13 @@ def _top_k_generator(k): targets = math_ops.to_int32(targets) if targets.get_shape().ndims > 1: targets = array_ops.squeeze(targets, axis=[1]) - return metric_ops.streaming_mean(nn.in_top_k(probabilities, targets, k)) + return metrics.mean(nn.in_top_k(probabilities, targets, k)) return _top_k def _accuracy(predictions, targets, weights=None): - return metric_ops.streaming_accuracy(predictions, targets, weights=weights) + return metrics.accuracy( + labels=targets, predictions=predictions, weights=weights) def _r2(probabilities, targets, weights=None): @@ -53,7 +54,7 @@ def _r2(probabilities, targets, weights=None): squares_residuals = math_ops.reduce_sum( math_ops.square(targets - probabilities), 0) score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) - return metric_ops.streaming_mean(score, weights=weights) + return metrics.mean(score, weights=weights) def _squeeze_and_onehot(targets, depth): @@ -62,7 +63,7 @@ def _squeeze_and_onehot(targets, depth): def _sigmoid_entropy(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.sigmoid_cross_entropy(probabilities, _squeeze_and_onehot( targets, @@ -71,7 +72,7 @@ def _sigmoid_entropy(probabilities, targets, weights=None): def _softmax_entropy(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.sparse_softmax_cross_entropy(probabilities, math_ops.to_int32(targets)), weights=weights) @@ -82,7 +83,7 @@ def _predictions(predictions, unused_targets, **unused_kwargs): def _class_log_loss(probabilities, targets, weights=None): - return metric_ops.streaming_mean( + return metrics.mean( losses.log_loss(probabilities, _squeeze_and_onehot(targets, array_ops.shape(probabilities)[1])), @@ -90,34 +91,36 @@ def _class_log_loss(probabilities, targets, weights=None): def _precision(predictions, targets, weights=None): - return metric_ops.streaming_precision(predictions, targets, weights=weights) + return metrics.precision( + labels=targets, predictions=predictions, weights=weights) def _precision_at_thresholds(predictions, targets, weights=None): - return metric_ops.streaming_precision_at_thresholds( - array_ops.slice(predictions, [0, 1], [-1, 1]), - targets, - np.arange( - 0, 1, 0.01, dtype=np.float32), + return metrics.precision_at_thresholds( + labels=targets, + predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), + thresholds=np.arange(0, 1, 0.01, dtype=np.float32), weights=weights) def _recall(predictions, targets, weights=None): - return metric_ops.streaming_recall(predictions, targets, weights=weights) + return metrics.recall( + labels=targets, predictions=predictions, weights=weights) def _recall_at_thresholds(predictions, targets, weights=None): - return metric_ops.streaming_recall_at_thresholds( - array_ops.slice(predictions, [0, 1], [-1, 1]), - targets, - np.arange( - 0, 1, 0.01, dtype=np.float32), + return metrics.recall_at_thresholds( + labels=targets, + predictions=array_ops.slice(predictions, [0, 1], [-1, 1]), + thresholds=np.arange(0, 1, 0.01, dtype=np.float32), weights=weights) def _auc(probs, targets, weights=None): - return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]), - targets, weights=weights) + return metrics.auc( + labels=targets, + predictions=array_ops.slice(probs, [0, 1], [-1, 1]), + weights=weights) _EVAL_METRICS = { diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc index 630c0607ae21d0276a9dd0507346d5dc4ed9f4a9..cfdc884277a025aa11995d329389f3748b17490c 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" +#include + #include "tensorflow/contrib/tensorboard/db/summary_converter.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -66,14 +68,9 @@ const char* kImagePluginName = "images"; const char* kAudioPluginName = "audio"; const char* kHistogramPluginName = "histograms"; -const int kScalarSlots = 10000; -const int kImageSlots = 10; -const int kAudioSlots = 10; -const int kHistogramSlots = 1; -const int kTensorSlots = 10; - const int64 kReserveMinBytes = 32; const double kReserveMultiplier = 1.5; +const int64 kPreallocateRows = 1000; // Flush is a misnomer because what we're actually doing is having lots // of commits inside any SqliteTransaction that writes potentially @@ -139,22 +136,6 @@ void PatchPluginName(SummaryMetadata* metadata, const char* name) { } } -int GetSlots(const Tensor& t, const SummaryMetadata& metadata) { - if (metadata.plugin_data().plugin_name() == kScalarPluginName) { - return kScalarSlots; - } else if (metadata.plugin_data().plugin_name() == kImagePluginName) { - return kImageSlots; - } else if (metadata.plugin_data().plugin_name() == kAudioPluginName) { - return kAudioSlots; - } else if (metadata.plugin_data().plugin_name() == kHistogramPluginName) { - return kHistogramSlots; - } else if (t.dims() == 0 && t.dtype() != DT_STRING) { - return kScalarSlots; - } else { - return kTensorSlots; - } -} - Status SetDescription(Sqlite* db, int64 id, const StringPiece& markdown) { const char* sql = R"sql( INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) @@ -481,24 +462,6 @@ class RunMetadata { return insert.StepAndReset(); } - Status GetIsWatching(Sqlite* db, bool* is_watching) - SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - mutex_lock lock(mu_); - if (experiment_id_ == kAbsent) { - *is_watching = true; - return Status::OK(); - } - const char* sql = R"sql( - SELECT is_watching FROM Experiments WHERE experiment_id = ? - )sql"; - SqliteStatement stmt; - TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); - stmt.BindInt(1, experiment_id_); - TF_RETURN_IF_ERROR(stmt.StepOnce()); - *is_watching = stmt.ColumnInt(0) != 0; - return Status::OK(); - } - private: Status InitializeUser(Sqlite* db, uint64 now) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (user_id_ != kAbsent || user_name_.empty()) return Status::OK(); @@ -659,43 +622,15 @@ class RunMetadata { /// \brief Tensor writer for a single series, e.g. Tag. /// -/// This class can be used to write an infinite stream of Tensors to the -/// database in a fixed block of contiguous disk space. This is -/// accomplished using Algorithm R reservoir sampling. -/// -/// The reservoir consists of a fixed number of rows, which are inserted -/// using ZEROBLOB upon receiving the first sample, which is used to -/// predict how big the other ones are likely to be. This is done -/// transactionally in a way that tries to be mindful of other processes -/// that might be trying to access the same DB. -/// -/// Once the reservoir fills up, rows are replaced at random, and writes -/// gradually become no-ops. This allows long training to go fast -/// without configuration. The exception is when someone is actually -/// looking at TensorBoard. When that happens, the "keep last" behavior -/// is turned on and Append() will always result in a write. -/// -/// If no one is watching training, this class still holds on to the -/// most recent "dangling" Tensor, so if Finish() is called, the most -/// recent training state can be written to disk. -/// -/// The randomly selected sampling points should be consistent across -/// multiple instances. -/// /// This class is thread safe. class SeriesWriter { public: - SeriesWriter(int64 series, int slots, RunMetadata* meta) - : series_{series}, - slots_{slots}, - meta_{meta}, - rng_{std::mt19937_64::default_seed} { + SeriesWriter(int64 series, RunMetadata* meta) : series_{series}, meta_{meta} { DCHECK(series_ > 0); - DCHECK(slots_ > 0); } Status Append(Sqlite* db, int64 step, uint64 now, double computed_time, - Tensor t) SQLITE_TRANSACTIONS_EXCLUDED(*db) + const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); if (rowids_.empty()) { @@ -705,41 +640,20 @@ class SeriesWriter { return s; } } - DCHECK(rowids_.size() == slots_); - int64 rowid; - size_t i = count_; - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - i = rng_() % (i + 1); - if (i < slots_) { - rowid = last_rowid_ = rowids_[i]; - } else { - bool keep_last; - TF_RETURN_IF_ERROR(meta_->GetIsWatching(db, &keep_last)); - if (!keep_last) { - ++count_; - dangling_tensor_.reset(new Tensor(std::move(t))); - dangling_step_ = step; - dangling_computed_time_ = computed_time; - return Status::OK(); - } - rowid = last_rowid_; - } - } + int64 rowid = rowids_.front(); Status s = Write(db, rowid, step, computed_time, t); if (s.ok()) { ++count_; - dangling_tensor_.reset(); } + rowids_.pop_front(); return s; } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); - // Short runs: Delete unused pre-allocated Tensors. - if (count_ < rowids_.size()) { + // Delete unused pre-allocated Tensors. + if (!rowids_.empty()) { SqliteTransaction txn(*db); const char* sql = R"sql( DELETE FROM Tensors WHERE rowid = ? @@ -747,19 +661,13 @@ class SeriesWriter { SqliteStatement deleter; TF_RETURN_IF_ERROR(db->Prepare(sql, &deleter)); for (size_t i = count_; i < rowids_.size(); ++i) { - deleter.BindInt(1, rowids_[i]); + deleter.BindInt(1, rowids_.front()); TF_RETURN_IF_ERROR(deleter.StepAndReset()); + rowids_.pop_front(); } TF_RETURN_IF_ERROR(txn.Commit()); rowids_.clear(); } - // Long runs: Make last sample be the very most recent one. - if (dangling_tensor_) { - DCHECK(last_rowid_ != kAbsent); - TF_RETURN_IF_ERROR(Write(db, last_rowid_, dangling_step_, - dangling_computed_time_, *dangling_tensor_)); - dangling_tensor_.reset(); - } return Status::OK(); } @@ -783,7 +691,6 @@ class SeriesWriter { Status Update(Sqlite* db, int64 step, double computed_time, const Tensor& t, const StringPiece& data, int64 rowid) { - // TODO(jart): How can we ensure reservoir fills on replace? const char* sql = R"sql( UPDATE OR REPLACE Tensors @@ -878,7 +785,7 @@ class SeriesWriter { // TODO(jart): Maybe preallocate index pages by setting step. This // is tricky because UPDATE OR REPLACE can have a side // effect of deleting preallocated rows. - for (int64 i = 0; i < slots_; ++i) { + for (int64 i = 0; i < kPreallocateRows; ++i) { insert.BindInt(1, series_); insert.BindInt(2, reserved_bytes); TF_RETURN_WITH_CONTEXT_IF_ERROR(insert.StepAndReset(), "i=", i); @@ -902,16 +809,10 @@ class SeriesWriter { mutex mu_; const int64 series_; - const int slots_; RunMetadata* const meta_; - std::mt19937_64 rng_ GUARDED_BY(mu_); uint64 count_ GUARDED_BY(mu_) = 0; - int64 last_rowid_ GUARDED_BY(mu_) = kAbsent; - std::vector rowids_ GUARDED_BY(mu_); + std::deque rowids_ GUARDED_BY(mu_); uint64 unflushed_bytes_ GUARDED_BY(mu_) = 0; - std::unique_ptr dangling_tensor_ GUARDED_BY(mu_); - int64 dangling_step_ GUARDED_BY(mu_) = 0; - double dangling_computed_time_ GUARDED_BY(mu_) = 0.0; TF_DISALLOW_COPY_AND_ASSIGN(SeriesWriter); }; @@ -928,10 +829,10 @@ class RunWriter { explicit RunWriter(RunMetadata* meta) : meta_{meta} {} Status Append(Sqlite* db, int64 tag_id, int64 step, uint64 now, - double computed_time, Tensor t, int slots) + double computed_time, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) LOCKS_EXCLUDED(mu_) { - SeriesWriter* writer = GetSeriesWriter(tag_id, slots); - return writer->Append(db, step, now, computed_time, std::move(t)); + SeriesWriter* writer = GetSeriesWriter(tag_id); + return writer->Append(db, step, now, computed_time, t); } Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) @@ -948,11 +849,11 @@ class RunWriter { } private: - SeriesWriter* GetSeriesWriter(int64 tag_id, int slots) LOCKS_EXCLUDED(mu_) { + SeriesWriter* GetSeriesWriter(int64 tag_id) LOCKS_EXCLUDED(mu_) { mutex_lock sl(mu_); auto spot = series_writers_.find(tag_id); if (spot == series_writers_.end()) { - SeriesWriter* writer = new SeriesWriter(tag_id, slots, meta_); + SeriesWriter* writer = new SeriesWriter(tag_id, meta_); series_writers_[tag_id].reset(writer); return writer; } else { @@ -1082,8 +983,7 @@ class SummaryDbWriter : public SummaryWriterInterface { TF_RETURN_IF_ERROR( meta_.GetTagId(db_, now, computed_time, tag, &tag_id, metadata)); TF_RETURN_WITH_CONTEXT_IF_ERROR( - run_.Append(db_, tag_id, step, now, computed_time, t, - GetSlots(t, metadata)), + run_.Append(db_, tag_id, step, now, computed_time, t), meta_.user_name(), "/", meta_.experiment_name(), "/", meta_.run_name(), "/", tag, "@", step); return Status::OK(); @@ -1155,8 +1055,7 @@ class SummaryDbWriter : public SummaryWriterInterface { int64 tag_id; TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t, - GetSlots(t, s->metadata())); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } // TODO(jart): Refactor Summary -> Tensor logic into separate file. @@ -1169,8 +1068,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kScalarPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kScalarSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { @@ -1201,8 +1099,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kHistogramPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kHistogramSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { @@ -1216,8 +1113,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kImagePluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kImageSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { @@ -1230,8 +1126,7 @@ class SummaryDbWriter : public SummaryWriterInterface { PatchPluginName(s->mutable_metadata(), kAudioPluginName); TF_RETURN_IF_ERROR(meta_.GetTagId(db_, now, e->wall_time(), s->tag(), &tag_id, s->metadata())); - return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), - std::move(t), kAudioSlots); + return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } Env* const env_; diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc index 2044692b6e746bc317843d715fa17ab5ec0bf99d..2e8d4109dd624ab66d774668ad04def9a7d3cdf2 100644 --- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -189,7 +189,7 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 user_id = QueryInt("SELECT user_id FROM Users"); int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); @@ -238,7 +238,7 @@ TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) { ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(10000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(1000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); } TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { @@ -255,7 +255,7 @@ TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) { TF_ASSERT_OK(writer_->WriteEvent(std::move(e))); TF_ASSERT_OK(writer_->Flush()); ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags")); - ASSERT_EQ(20000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + ASSERT_EQ(2000LL, QueryInt("SELECT COUNT(*) FROM Tensors")); int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'"); int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'"); EXPECT_GT(tag1_id, 0LL); diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 7a8a71ac7f491ec48a47ae1ea1aff750a587beaa..a5d8b061b6b26f9d05be40a1162481ae219b0e9c 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -192,7 +192,7 @@ tf_py_wrap_cc( ":trt_conversion", ":trt_engine_op_kernel", "//tensorflow/core:framework_lite", - "//util/python:python_headers", + "//third_party/python_runtime:headers", ], ) @@ -303,7 +303,7 @@ tf_cuda_library( ], deps = [ "//tensorflow/core:framework_lite", - "//tensorflow/core:platform_base", + "//tensorflow/core:lib_proto_parsing", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]), diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index b7b26cfb1c05ae74e932c8b9cb2479cfca308514..da4dd5a14cd74591fc9df63cd5868044e4e369ec 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -91,8 +91,11 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph, if (!subgraph_node_ids.count(edge->src()->id()) && !edge->src()->IsSource() && !edge->IsControlEdge()) { incoming_edges->insert(edge); + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " Y, "; } else { - VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, "; + VLOG(2) << "INCOMING " << edge->src()->name() << " -> " << node->name() + << " N, "; } } } @@ -106,10 +109,12 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph, for (const tensorflow::Edge* edge : node->out_edges()) { if (!subgraph_node_ids.count(edge->dst()->id()) && !edge->dst()->IsSink() && !edge->IsControlEdge()) { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " Y, "; outgoing_edges->insert(edge); } else { - VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, "; + VLOG(2) << "OUTGOING " << node->name() << " -> " << edge->dst()->name() + << " N, "; } } } @@ -181,29 +186,27 @@ struct ConvertGraphParams { static tensorflow::Status FillSubGraphEdgeSets(ConvertGraphParams* p) { GetSubGraphIncomingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_incoming_edges); + + std::set> unique_tensors; + // Add only unique input source nodes. If output of an outside node is shared + // between multiple nodes inside the engine, only one edge should be created for (const tensorflow::Edge* edge : p->subgraph_incoming_edges) { - p->subgraph_inputs.push_back({edge->src()->id(), edge->src_output()}); - } - auto output_name_to_index_map = BuildTensorNameMap(p->output_names); - std::set> subgraph_outputs_set; - // Collect outputs referenced from output_names - for (int node_id : p->subgraph_node_ids) { - tensorflow::Node* node = p->graph.FindNodeId(node_id); - if (output_name_to_index_map.count(node->name())) { - for (int index : output_name_to_index_map.at(node->name())) { - subgraph_outputs_set.insert({node_id, index}); - } - } + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } + p->subgraph_inputs.insert(p->subgraph_inputs.begin(), unique_tensors.begin(), + unique_tensors.end()); GetSubGraphOutgoingEdges(p->graph, p->subgraph_node_ids, &p->subgraph_outgoing_edges); + unique_tensors.clear(); + // Similar to above, if multiple ouside nodes are sharing the output of an + // internal node only one output port should be created and shared between + // outputs for (const tensorflow::Edge* edge : p->subgraph_outgoing_edges) { - subgraph_outputs_set.insert({edge->src()->id(), edge->src_output()}); + unique_tensors.insert({edge->src()->id(), edge->src_output()}); } - p->subgraph_outputs.reserve(subgraph_outputs_set.size()); + p->subgraph_outputs.reserve(unique_tensors.size()); p->subgraph_outputs.insert(p->subgraph_outputs.begin(), - subgraph_outputs_set.begin(), - subgraph_outputs_set.end()); + unique_tensors.begin(), unique_tensors.end()); return tensorflow::Status::OK(); } @@ -225,7 +228,6 @@ tensorflow::Status GetCalibNode(ConvertGraphParams* params) { for (auto in_edge : params->subgraph_incoming_edges) { // loop over incoming edges and // attach them to calib node - // tensorflow::Node* src_node = in_edge->src(); auto src_output = in_edge->src_output(); auto dst_node = in_edge->dst(); auto dst_input = in_edge->dst_input(); @@ -257,19 +259,24 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { for (size_t i = 0; i < params->subgraph_inputs.size(); ++i) { subgraph_edge_to_input_map.insert({params->subgraph_inputs.at(i), i}); } + std::set> unique_tensors; for (const tensorflow::Edge* edge : params->subgraph_incoming_edges) { std::pair old_src = {edge->src()->id(), edge->src_output()}; + if (unique_tensors.count(old_src)) continue; + unique_tensors.insert(old_src); int new_src_output = subgraph_edge_to_input_map.at(old_src); params->graph.AddEdge(edge->src(), edge->src_output(), trt_node, new_src_output); + VLOG(1) << "Wire " << edge->src()->name() << ":" << edge->src_output() + << " -> " << trt_node->name() << ":" << new_src_output; params->graph.RemoveEdge(edge); } - - VLOG(2) << "new wiring edges: " << trt_node->in_edges().size(); - for (const tensorflow::Edge* edge : trt_node->in_edges()) { - VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + if (VLOG_IS_ON(2)) { + VLOG(2) << "new edge count: " << trt_node->in_edges().size(); + for (const tensorflow::Edge* edge : trt_node->in_edges()) { + VLOG(2) << edge->src()->name() << " port: " << edge->src_output(); + } } - TF_RETURN_IF_ERROR(status); // Re-map outgoing edges to use the new TRT node instead of the orig subgraph @@ -283,6 +290,8 @@ tensorflow::Status ConvertSubGraphToTensorRT(ConvertGraphParams* params) { int new_src_output = subgraph_edge_to_output_map.at(old_src); TF_RETURN_IF_ERROR(params->graph.UpdateEdge( trt_node, new_src_output, edge->dst(), edge->dst_input())); + VLOG(1) << "Wire " << trt_node->name() << ":" << new_src_output << " -> " + << edge->dst()->name() << ":" << edge->dst_input(); } // Remove the original subgraph for (int node_id : params->subgraph_node_ids) { @@ -317,9 +326,12 @@ tensorflow::Status ConvertCalibGraphToInferGraph( tensorflow::GraphConstructorOptions(), graph_def, &graph)); // get calib nodes std::vector calib_nodes; - for (auto node : graph.op_nodes()) { + std::vector topo_order; + tensorflow::GetPostOrder(graph, &topo_order); + for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { + auto node = *rit; if (node->type_string() == "TRTCalibOp") { - VLOG(1) << "Found Calib Node"; + VLOG(1) << "Found Calib Node " << node->name(); calib_nodes.push_back(node); } } diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 32b211dcd1e282d334327b83a27f9401de7f310a..4e4d295538edadd26a347a38ec141737f097f26f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -362,10 +362,11 @@ void ReorderCKtoKC(const TRT_ShapedWeights& iweights, break; } case tensorflow::DataType::DT_HALF: { - Reorder2({k, c}, static_cast(iweights.GetValues()), - istrides, static_cast( - const_cast(oweights->GetValues())), - ostrides); + Reorder2( + {k, c}, static_cast(iweights.GetValues()), + istrides, + static_cast(const_cast(oweights->GetValues())), + ostrides); break; } default: @@ -1179,9 +1180,9 @@ tensorflow::Status BinaryTensorOpTensor( CHECK_EQ_TYPE(tensor_r->getType(), dtype); auto op_pair = ops.find(node_def.op()); if (op_pair == ops.end()) - return tensorflow::errors::Unimplemented("binary op: " + node_def.op() + - " not supported at: " + - node_def.name()); + return tensorflow::errors::Unimplemented( + "binary op: " + node_def.op() + + " not supported at: " + node_def.name()); nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise( *const_cast(tensor_l), @@ -2138,9 +2139,7 @@ void Converter::register_op_converters() { } } // namespace -tensorflow::Status GetTensorRTGraph(tensorrt::convert::SubGraphParams& s) { - return tensorflow::errors::Unimplemented("Not implemented yet"); -} + tensorflow::Status ConvertCalibrationNodeToEngineNode( tensorflow::Graph& graph, tensorflow::Node* c_node) { const auto ndef = c_node->def(); @@ -2164,9 +2163,23 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( for (auto n : graph.op_nodes()) { node_maps.insert({n->name(), n}); } + std::set subgraph_ids; + for (const auto internal_node : segment_nodes) { + subgraph_ids.insert(node_maps.at(internal_node)->id()); + } + if (VLOG_IS_ON(2)) { + string node_names = StrCat(c_node->name(), " segment nodes= "); + + for (const auto& node_name : segment_nodes) { + StrAppend(&node_names, node_name, ", "); + } + VLOG(2) << node_names; + } + VLOG(1) << "Output Nodes:"; std::vector out_types; std::vector out_edges; + for (auto& i : output_nodes) { auto node_port = tensorflow::str_util::Split(i, ":"); VLOG(1) << " " << i << " in graph " << node_maps.count(i); @@ -2186,18 +2199,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( out_types.push_back(out_node->output_type(0)); } for (auto out_edge : out_node->out_edges()) { + if (subgraph_ids.count(out_edge->dst()->id())) + continue; // skip internal edges; if (out_edge->src_output() == port) { out_edges.push_back(out_edge); - break; + VLOG(1) << "OUTPUT EDGE " << out_edge->src()->name() << ":" + << out_edge->src_output() << " -> " << out_edge->dst()->name() + << ":" << out_edge->dst_input(); } } } else { LOG(WARNING) << " couldn't find output node " << out_node_name; } } - VLOG(1) << "Input Nodes:"; - for (auto& i : input_names) { - VLOG(1) << " " << i << " in graph " << node_maps.count(i); + if (VLOG_IS_ON(1)) { + VLOG(1) << c_node->name() << " Input Nodes:"; + for (auto& i : input_names) { + VLOG(1) << " Input " << i << " in graph " << node_maps.count(i); + } } auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); auto resmgr = trt_rm->getManager("TRTCalibOps"); @@ -2231,14 +2250,24 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( calib_res->builder_ = nullptr; tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); std::vector income_edges; + income_edges.resize(c_node->num_inputs()); for (const auto in_edge : c_node->in_edges()) { auto src = in_edge->src(); int dest_port = in_edge->dst_input(); - income_edges.emplace_back(src->name(), in_edge->src_output(), - c_node->input_type(dest_port)); + VLOG(1) << "Incoming connection " << src->name() << ":" + << in_edge->src_output() << " -> " << c_node->name() << ":" + << dest_port; + income_edges.at(dest_port) = {src->name(), in_edge->src_output(), + c_node->input_type(dest_port)}; } tensorflow::gtl::ArraySlice input_list( income_edges); + if (VLOG_IS_ON(2)) { + for (const auto& inp : input_list) { + VLOG(2) << " Input from inputlist " << inp.node << ":" << inp.index << " " + << tensorflow::DataTypeString(inp.data_type); + } + } op_builder.Input(input_list); tensorflow::NodeDef engine_node; const char* engine_plan_data = static_cast(engine_plan->data()); @@ -2255,13 +2284,26 @@ tensorflow::Status ConvertCalibrationNodeToEngineNode( } auto trt_engine_node = graph.AddNode(engine_node, &status); TF_RETURN_IF_ERROR(status); - for (size_t i = 0; i < out_edges.size(); i++) { - VLOG(1) << "Connecting trt_engine_node output " << i << " with " - << out_edges.at(i)->dst()->name() << " port " - << out_edges.at(i)->dst_input(); - TF_RETURN_IF_ERROR(graph.UpdateEdge(trt_engine_node, i, - out_edges.at(i)->dst(), - out_edges.at(i)->dst_input())); + std::map port_map; + for (size_t t = 0; t < output_nodes.size(); t++) { + port_map.insert({output_nodes.at(t), t}); + } + for (auto& i : out_edges) { + string s(i->src()->name()); + if (i->src_output()) StrAppend(&s, ":", i->src_output()); + int out_port = port_map.at(s); + VLOG(1) << "Connecting " << trt_engine_node->name() << ":" << out_port + << " -> " << i->dst()->name() << ":" << i->dst_input(); + TF_RETURN_IF_ERROR( + graph.UpdateEdge(trt_engine_node, out_port, i->dst(), i->dst_input())); + } + for (const auto ed : trt_engine_node->in_edges()) { + VLOG(1) << "In Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); + } + for (const auto ed : trt_engine_node->out_edges()) { + VLOG(1) << "Out Edge " << ed->src()->name() << ":" << ed->src_output() + << " -> " << ed->dst()->name() << ":" << ed->dst_input(); } VLOG(1) << "Segment nodes:"; for (auto& i : segment_nodes) { @@ -2332,6 +2374,7 @@ tensorflow::Status ConvertSubgraph( std::vector* output_names, std::vector* output_dtypes, const string& engine_name) { + std::set added_tensors; for (const std::pair& input : s.input_inds) { VLOG(2) << "parsing input. Node id= " << input.first; int node_id = input.first; @@ -2374,7 +2417,6 @@ tensorflow::Status ConvertSubgraph( auto op_info = op_info_vec.at(shape_inference_output_idx); tensorflow::DataType tf_dtype = op_info.dtype(); - input_dtypes->push_back(tf_dtype); nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); auto type_status = ConvertDType(tf_dtype, &dtype); @@ -2410,8 +2452,10 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) { input_tensor_name = StrCat(node_name, ":", output_idx); } - + if (added_tensors.count(input_tensor_name)) continue; + added_tensors.insert(input_tensor_name); input_names->push_back(input_tensor_name); + input_dtypes->push_back(tf_dtype); nvinfer1::ITensor* input_tensor = converter.network()->addInput( input_tensor_name.c_str(), dtype, input_dim_pseudo_chw); @@ -2435,6 +2479,7 @@ tensorflow::Status ConvertSubgraph( // Gather output metadata int trt_engine_op_output_idx = 0; + added_tensors.clear(); for (const std::pair& output : s.output_inds) { int node_id = output.first; int output_idx = output.second; @@ -2451,6 +2496,8 @@ tensorflow::Status ConvertSubgraph( if (output_idx != 0) tensorflow::strings::StrAppend(&tensor_name, ":", output_idx); VLOG(2) << "Output tensor name: " << tensor_name; + if (added_tensors.count(tensor_name)) continue; + added_tensors.insert(tensor_name); output_names->push_back(tensor_name); auto tensor_or_weights = converter.get_tensor(tensor_name); if (!tensor_or_weights.is_tensor()) { @@ -2534,7 +2581,7 @@ tensorflow::Status InjectCalibrationNode(tensorrt::convert::SubGraphParams& s) { // Build the TRT op // TODO(sami,ben,jie): proper naming! tensorflow::NodeDefBuilder op_builder(calib_op_name, "TRTCalibOp"); - SetInputList(s, &op_builder, &input_names, &input_dtypes); + TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); std::vector segment_names; segment_names.reserve(s.subgraph_node_ids.size()); @@ -2632,7 +2679,7 @@ tensorflow::Status ConvertSubGraphToTensorRTNodeDef( // Build the TRT op tensorflow::NodeDefBuilder op_builder(engine_name, "TRTEngineOp"); - SetInputList(s, &op_builder, &input_names, &input_dtypes); + TF_RETURN_IF_ERROR(SetInputList(s, &op_builder, &input_names, &input_dtypes)); VLOG(0) << "Finished op preparation"; diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 2de3923b06a8ddf89c7e6f922138a85f55a618d6..f5b2d258d70d5577a9d68f2d9f6d6e678ede97ce 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -275,13 +275,13 @@ TEST_F(SegmentTest, Multiple) { // Expect two subgraphs EXPECT_EQ(segments.size(), 2); - std::vector expected0{"add0", "add1", "add2", "add3"}; + std::vector expected0{"add6", "add8"}; for (const auto& ex : expected0) { EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end()) << "Missing expected node " << ex; } - std::vector expected1{"add6", "add8"}; + std::vector expected1{"add0", "add1", "add2", "add3"}; for (const auto& ex : expected1) { EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end()) << "Missing expected node " << ex; diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index dbf1ab6bbf0ddc7429d8e19279451eb862981e0c..3b2d7adfff6b8de3145a73756a8b5306445034c5 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -53,7 +53,7 @@ tf_cc_binary( "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/platform/cloud:gcs_file_system", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 816897499b7a49365060c026d60b977990f3ecdc..f80f5652af79d410946971573ae160fdd0b85f6d 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -18,7 +18,7 @@ limitations under the License. // Initiates a TPU profiling on the TPUProfiler service at service_addr, // receives and dumps the profile data to a tensorboard log directory. -#include "grpc++/grpc++.h" +#include "grpcpp/grpcpp.h" #include #include @@ -79,7 +79,9 @@ ProfileRequest PopulateProfileRequest(int duration_ms, request.set_repository_root(repository_root); request.set_session_id(session_id); } + request.add_tools("op_profile"); request.add_tools("input_pipeline"); + request.add_tools("memory_viewer"); request.add_tools("overview_page"); *request.mutable_opts() = opts; std::cout << "Limiting the number of trace events to " << kMaxEvents diff --git a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc index 73d941e5e99801d239441e438e9d640478c4e6b6..98cc31f18d2d34765f2c123c3d34207802541036 100644 --- a/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/dump_tpu_profile.cc @@ -38,6 +38,7 @@ namespace { using ::tensorflow::io::JoinPath; using ::tensorflow::protobuf::util::JsonOptions; using ::tensorflow::protobuf::util::MessageToJsonString; +using ::tensorflow::str_util::EndsWith; using ::tensorflow::strings::StrCat; constexpr char kGraphRunPrefix[] = "tpu_profiler.hlo_graph."; @@ -46,6 +47,9 @@ constexpr char kJsonTraceFileName[] = "trace.json.gz"; constexpr char kProfilePluginDirectory[] = "plugins/profile/"; constexpr char kProtoTraceFileName[] = "trace"; +constexpr char kFlatProfilerFileName[] = "flat_profiler.pb"; +constexpr char kTfStatsHelperSuffix[] = "tf_stats_helper_result"; + Status WriteGzippedDataToFile(const string& filename, const string& data) { std::unique_ptr file; TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(filename, &file)); @@ -107,6 +111,10 @@ Status DumpToolDataToLogDirectory(StringPiece run_dir, const string& host_prefix, const tensorflow::ProfileToolData& tool, std::ostream* os) { + // Don't save the intermediate results for combining the per host tool data. + if (EndsWith(tool.name(), kFlatProfilerFileName) || + EndsWith(tool.name(), kTfStatsHelperSuffix)) + return Status::OK(); string path = JoinPath(run_dir, StrCat(host_prefix, tool.name())); TF_RETURN_IF_ERROR(WriteStringToFile(Env::Default(), path, tool.data())); if (os) { diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto index b9ac1a550c87e055fd5d555c346d5ec545bbe634..2b13343efa4e82386cb9259432b854be3ec821f7 100644 --- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -87,6 +87,8 @@ message StepInfoResult { optional uint64 wait_duration_ps = 5; // The time spent on cross-replica-sum in picoseconds. optional uint64 crs_duration_ps = 6; + // Percentage of unit b time spent on infeed. + optional double unit_b_infeed_percent = 7; } // Result proto for a sequence of steps. diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto index 7be694e866729c58efae4ccf7932dd929c03ed91..f0fca63db0bca80cdaa27e491b2a03ae2246c007 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto @@ -68,7 +68,8 @@ message ProfileRequest { } message ProfileToolData { - // The tool's name which this data is associated. (e.g. "input_pipeline".) + // The file name which this data is associated (e.g. "input_pipeline.json", + // "cluster_xxx.memory_viewer.json"). string name = 1; // The data payload (likely json) for the specific tool. diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto index 8b0bbde98e6a1dee8ade789328f3ba0624049562..d3c34bfd490080b86cf3d8b893c550f3a87bbbed 100644 --- a/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto +++ b/tensorflow/contrib/tpu/profiler/tpu_profiler_analysis.proto @@ -38,6 +38,9 @@ message EnumProfileSessionsAndToolsResponse { message ProfileSessionDataRequest { string repository_root = 1; string session_id = 2; + // Which host the data is associated. if empty, data from all hosts are + // aggregated. + string host_name = 5; // Which tool string tool_name = 3; // Tool's specific parameters. e.g. TraceViewer's viewport etc diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index e2f57ce9c503f454d2bf1eae895922a4a4d26ced..cd0fd6ae8a2b35efa85bb4583ed3846a1a33395f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -21,6 +21,7 @@ from __future__ import print_function from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function @@ -125,7 +126,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name, num_replicas): + def __init__(self, name, num_replicas, pivot): + """Builds a new TPUReplicateContext. + + Args: + name: a unique name for the context, used to populate the `_tpu_replicate` + attribute. + num_replicas: an integer that gives the number of replicas for the + computation. + pivot: a pivot node. Nodes in the TPUReplicateContext that do not have any + inputs will have a control dependency on the pivot node. This ensures + that nodes are correctly included in any enclosing control flow + contexts. + """ super(TPUReplicateContext, self).__init__() self._num_replicas = num_replicas self._outer_device_function_stack = None @@ -137,6 +150,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._host_compute_core = [] self._name = name self._unsupported_ops = [] + self._pivot = pivot def report_unsupported_operations(self): if self._unsupported_ops: @@ -261,9 +275,6 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access super(TPUReplicateContext, self).Enter() - def Exit(self): - super(TPUReplicateContext, self).Exit() - def HostComputeCore(self): return self._host_compute_core @@ -299,10 +310,64 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) + # Remove any control edges from outer control flow contexts. These may cause + # mismatched frame errors. + control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + + if not op.inputs: + # Add a control edge from the control pivot to this op. + if not control_inputs: + # pylint: disable=protected-access + op._add_control_input(self.GetControlPivot()) + # pylint: enable=protected-access + else: + for index in xrange(len(op.inputs)): + x = op.inputs[index] + real_x = self.AddValue(x) + if real_x != x: + op._update_input(index, real_x) # pylint: disable=protected-access + + if external_inputs: + # Use an identity to pull control inputs as data inputs. Note that we + # ignore ops which don't have outputs. TODO(phawkins): fix that. + with ops.control_dependencies(None): + self.Enter() + external_inputs = [ + array_ops.identity(x.outputs[0]).op + for x in external_inputs + if x.outputs + ] + self.Exit() + # pylint: disable=protected-access + op._add_control_inputs(external_inputs) + # pylint: enable=protected-access + + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + context = self + while context is not None: + # pylint: disable=protected-access + context._values.update(output_names) + context = context._outer_context + # pylint: enable=protected-access + + if self._outer_context: + self._outer_context.AddInnerOp(op) + def AddValue(self, val): + if val.name in self._values: + # Use the real value if it comes from outer context. + result = self._external_values.get(val.name) + return val if result is None else result + result = val + self._values.add(val.name) if self._outer_context: result = self._outer_context.AddValue(val) + self._values.add(result.name) + + self._external_values[val.name] = result + return result def AddInnerOp(self, op): @@ -318,17 +383,30 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): # grad_state should be as if this is the top-level gradient state. return None + @property + def back_prop(self): + """Forwards to the enclosing while context, if any.""" + if self.GetWhileContext(): + return self.GetWhileContext().back_prop + return False -def outside_compilation(computation, args=None): + def GetControlPivot(self): + return self._pivot + + +def outside_compilation(computation, *args, **kwargs): """Builds part of a computation outside any current TPU replicate scope. Args: computation: A Python function that builds the computation to place on the host. - args: Inputs to pass to computation. + *args: the positional arguments for the computation. + **kwargs: the keyword arguments for the computation. + Returns: The Tensors returned by computation. """ + args = [] if args is None else args graph = ops.get_default_graph() # If we are in a TPUReplicateContext, signal that we are now @@ -340,7 +418,7 @@ def outside_compilation(computation, args=None): context._EnterOutsideCompilationScope() # pylint: disable=protected-access context = context.outer_context - retval = computation(*args) + retval = computation(*args, **kwargs) # If we are in a TPUReplicateContext, signal that we are no longer # outside_compilation @@ -501,7 +579,9 @@ def split_compile_and_replicate(computation, tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) cluster_name = graph.unique_name("cluster") - context = TPUReplicateContext(name=cluster_name, num_replicas=num_replicas) + pivot = control_flow_ops.no_op(name=cluster_name + "/pivot") + context = TPUReplicateContext( + name=cluster_name, num_replicas=num_replicas, pivot=pivot) try: context.Enter() @@ -543,10 +623,16 @@ def split_compile_and_replicate(computation, vscope.set_use_resource(saved_use_resource) + # If the computation returns `None`, make it an empty tuple. + if outputs is None: + outputs = tuple() # If the computation only returned one value, makes it a tuple. if not isinstance(outputs, (list, tuple)): outputs = (outputs,) + # Append `no_op` here so that fetching any return value of this function + # will trigger TPUExecute node. + outputs += (control_flow_ops.no_op(),) try: with ops.device(core(0)): outputs = [ @@ -578,6 +664,7 @@ def split_compile_and_replicate(computation, with ops.device(t.device if t.device else core(0)): new_output_tensors.append(array_ops.identity(t)) output_tensors = new_output_tensors + context.ExitResult(output_tensors) finally: context.report_unsupported_operations() context.Exit() @@ -867,3 +954,152 @@ def rewrite(computation, device_assignment=device_assignment, name=name)[0] # pylint: enable=indexing-exception + + # Operations that indicate some error in the user's inference graph. +_BLACKLISTED_INFERENCE_OPS = set([ + "ReadVariableOp", + "AssignVariableOp", + "AssignAddVariableOp", + "AssignSubVariableOp", + "VarHandleOp", + "Variable", + "VariableV2", +]) + + +class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext): + """A `ControlFlowContext` for nodes inside a TPU inference computation. + + The primary role of `TPUReplicateContext` is to sanity check operators inside + a tpu.rewrite_for_inference() computation. + """ + + def __init__(self, name): + super(_TPUInferenceContext, self).__init__() + self._name = name + + def AddOp(self, op): + self._AddOpInternal(op) + + def _AddOpInternal(self, op): + # pylint: disable=protected-access + if op.type in _BLACKLISTED_INFERENCE_OPS: + raise NotImplementedError( + "Operation of type %s (%s) is not supported on the TPU for inference." + " Execution will fail if this op is used in the graph. Make sure your" + " variables are using variable_scope." % (op.type, op.name)) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + result = val + if self._outer_context: + result = self._outer_context.AddValue(val) + return result + + def AddInnerOp(self, op): + self._AddOpInternal(op) + + @property + def grad_state(self): + return None + + +@experimental +def validate_inference_rewrite_for_variables(graph): + """Validates whether rewrite_for_inference() 'worked' for variables. + + The rewrite_for_inference() method is supposed to append + GuaranteeConstOps after ReadVariableOps, but this mechanism works only + if you are using tf.get_variable() to create and access variables in your + tpu computation. This validation method can be called immediately after + calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps + where added to the graph. + + Typical usages: + tpu.validate_inference_rewrite_for_variables(tf.get_default_graph()) + + tpu.validate_inference_rewrite_for_variables(sess.graph) + + Args: + graph: The graph which needs to be validated. + Raises: + RuntimeError: if validation failed. + """ + if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]): + raise RuntimeError( + "No GuaranteeConst ops found in the graph after " + "running tpu.rewrite_for_inference(...). Please " + "check that you are using tf.get_variable() to " + "create and access variables in your tpu " + "computation.") + + +@experimental +def rewrite_for_inference(computation, + inputs=None, + infeed_queue=None, + device_assignment=None, + name=None): + """Rewrites `computation` for inference on a TPU system. + + Other than 'rewriting' the computation to run on a TPU, if using variables + in your computation, it moves the ReadVariableOps outside the TPU + computation, and adds GuaranteeConst ops just after the ReadVariableOps. + This mechanism works only if you are using tf.get_variable() to create and + access variables in your tpu computation. You can validate whether + this worked, by calling validate_inference_rewrite_for_variables() method + immediately after this method to check whether GuaranteeConstOps where + added to the graph. + + Args: + computation: A Python function that builds a computation to apply + to the input. If the function takes n inputs, 'inputs' should be + a list of n tensors. If the function returns m outputs, rewrite + will return a list of m tensors. + inputs: A list of input tensors or `None` (equivalent to an empty list). + infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple + of arguments as inputs to `computation`. + device_assignment: if not `None`, a `DeviceAssignment` describing the + mapping between logical cores in the computation with physical cores in + the TPU topology. May be omitted for a single-core computation, in which + case the core attached to task 0, TPU device 0 is used. + name: The name of the operator. + Returns: + A list of output tensors. + """ + + def guarantee_const_getter(getter, name, *args, **kwargs): + with ops.control_dependencies(None): + return array_ops.guarantee_const( + getter(name, *args, **kwargs), name=name + "/GuaranteeConst") + + def wrapped_computation(*args, **kwargs): + """Execute computation under `_TPUInferenceContext`.""" + context = _TPUInferenceContext( + name=ops.get_default_graph().unique_name("rewrite_for_inference")) + try: + context.Enter() + + vscope = variable_scope.get_variable_scope() + prev_custom_getter = vscope.custom_getter + prev_caching_device = vscope.caching_device + vscope.set_custom_getter(guarantee_const_getter) + vscope.set_caching_device(lambda op: op.device) + + result = computation(*args, **kwargs) + + vscope.set_custom_getter(prev_custom_getter) + vscope.set_caching_device(prev_caching_device) + finally: + context.Exit() + return result + + # pylint: disable=undefined-variable + return rewrite( + wrapped_computation, + inputs=inputs, + infeed_queue=infeed_queue, + device_assignment=device_assignment, + name=name) + # pylint: enable=undefined-variable diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 77d117ba78276aa35e55eb3a524575acc1cf607b..e94bd78833f6cbe9adb1b6ca3f29a88bd8a53f64 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -46,6 +46,8 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -61,6 +63,7 @@ from tensorflow.python.ops import summary_ops_v2 as contrib_summary from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import tag_constants from tensorflow.python.summary import summary from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import evaluation @@ -71,6 +74,7 @@ from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect + _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. _TPU_ESTIMATOR = 'tpu_estimator' @@ -81,6 +85,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _ONE_GIGABYTE = 1024 * 1024 * 1024 _TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' _TPU_TRAIN_OP = '_tpu_train_op' +_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] @@ -117,6 +122,33 @@ def _create_global_step(graph): def _create_or_get_iterations_per_loop(): + """Creates or gets the iterations_per_loop variable. + + In TPUEstimator, the user provided computation, the model_fn, is wrapped + inside a tf.while_loop for peak performance. The iterations of the loop are + specified by this variable, which adjusts its value on the CPU after each TPU + program execution and before the next TPU execution. + + The purpose of using a variable, rather then a constant, is to allow + TPUEstimator adapt the TPU training iterations according to the final steps + specified by users. For example, if the user sets the iterations_per_loop as 4 + in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop + variable will have the following value before each TPU training. + + - 1-th TPU execution: iterations_per_loop = 4 + - 2-th TPU execution: iterations_per_loop = 4 + - 3-th TPU execution: iterations_per_loop = 2 + + As model_fn increases the global step once per train_op invocation, the global + step is 10 after all TPU executions, matching the steps=10 inputs passed in by + users. + + Returns: + A TF non-trainable resource variable. + + Raises: + RuntimeError: If multi iterations_per_loop variables were found. + """ graph = ops.get_default_graph() collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) iter_vars = graph.get_collection(collection_name) @@ -383,20 +415,21 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): return def _cancel_session(): - # Close the session to avoid the main thread from hanging. If input - # pipeline triggers any error, the infeed thread dies but the main thread - # for TPU computation waits for the infeed enqueue forever. Close the - # Session to cancel the main thread Session.run execution. - # - # We sleep for a few seconds before closing to give some time - # for the TPU compilation error, if any, propagating, from TPU to CPU - # host. Compilation errors should be reported by the main thread so that - # the program can be interrupted and users can take action. Due to a race - # condition, the infeed thread might see an error first. Closing the - # session here immediately would result in a session cancellation - # exception in the main thread, instead of the expected compile error. - # User code that depends on having the proper exception type will - # therefore be confused. + """Close the session to avoid the main thread from hanging. + + If input pipeline triggers any error, the infeed thread dies but the main + thread for TPU computation waits for the infeed enqueue forever. Close the + Session to cancel the main thread Session.run execution. + + We sleep for a few seconds before closing to give some time for the TPU + compilation error, if any, propagating, from TPU to CPU host. Compilation + errors should be reported by the main thread so that the program can be + interrupted and users can take action. Due to a race condition, the + infeed thread might see an error first. Closing the session here + immediately would result in a session cancellation exception in the main + thread, instead of the expected compile error. User code that depends on + having the proper exception type will therefore be confused. + """ time.sleep(5) # If the main session is still running, the infeed/outfeed errors are @@ -716,6 +749,15 @@ def generate_per_host_enqueue_ops_fn_for_host( tpu_ordinal_function = None def enqueue_ops_fn(): + """A Fn returning the TPU infeed enqueue ops. + + By providing as a Fn, it can be invoked inside the tf.while_loop such that + the input pipeline for multiple iterations can be executed by one + Session.run call. + + Returns: + list of dict of ops. + """ with ops.device(device): num_of_replicas_per_host = ctx.num_of_replicas_per_host # Convert user input to features and labels. If the user returns a @@ -1090,10 +1132,16 @@ class _InputPipeline(object): return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator def _validate_input_pipeline(self): - # Perform some sanity checks to log user friendly information. We should - # error out to give users better error message. But, if - # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - # user code, so, log a warning. + """Validates the input pipeline. + + Perform some sanity checks to log user friendly information. We should + error out to give users better error message. But, if + _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break + user code, so, log a warning. + + Raises: + RuntimeError: If the validation failed. + """ if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): err_msg = ('Input pipeline contains one or more QueueRunners. ' 'It could be slow and not scalable. Please consider ' @@ -1264,13 +1312,11 @@ class _ModelFnWrapper(object): 'estimator_spec used by TPU prediction must have type' '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) + self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) + captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) to_record = {} identity_fn = lambda **kwargs: kwargs - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. - if not isinstance(tpu_estimator_spec.predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] to_record['signals'] = [identity_fn, stopping_signals] if tpu_estimator_spec.host_call is not None: @@ -1282,8 +1328,70 @@ class _ModelFnWrapper(object): return predict_step, host_calls, captured_scaffold_fn + def _verify_tpu_spec_predictions(self, predictions): + """Validates TPUEstimatorSpec.predictions dict.""" + # TODO(xiejw): Adds validation for prediction dictionrary. + # TODO(xiejw): Adds support for single tensor as predictions. + if not isinstance(predictions, dict): + raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') + + for (key, tensor) in predictions.items(): + if tensor.shape[0].value is None: + raise ValueError( + 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' + 'dynamic shape (should be static). Tensor: {}'.format( + key, tensor)) + return predictions + + def _validate_model_features_and_labels(self, + features, + labels, + is_export_mode): + """Validates that the features and labels for the model function are valid. + + A valid features/labels object is the one with: + - Type: Tensor or a dictionary of Tensors + - Static shape if is_export_mode is False. + + Args: + features: the features that would be input to the model function. + labels: the labels that would be input to the model function. + is_export_mode: boolean value specifying if in export mode. + + Raises: + TypeError: If features/labels are not of the correct type. + ValueError: If features/labels have dynamic shape. + """ + + def validate(obj, obj_name): + """Helper validate function.""" + if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict): + raise TypeError( + 'The {} to the model returned by input_fn must be either a Tensor ' + 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name, + obj)) + if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): + return + if isinstance(obj, ops.Tensor): + if not obj.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static shape.' + ' Tensor: {}'.format(obj_name, obj)) + else: + for (key, tensor) in obj.items(): + if not tensor.get_shape().is_fully_defined(): + raise ValueError( + 'The {} to the model returned by input_fn must have static ' + 'shape. Key: \'{}\', Tensor: {}'.format( + obj_name, key, tensor)) + + validate(features, 'features') + if labels is not None: + validate(labels, 'labels') + def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" + self._validate_model_features_and_labels(features, labels, is_export_mode) model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} @@ -1760,8 +1868,40 @@ class TPUEstimator(estimator_lib.Estimator): Exporting ========= - Exporting `SavedModel` support on TPU is not yet implemented. So, - `export_savedmodel` is executed on CPU, even if `use_tpu` is true. + `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, + and another with `tag_constants.SERVING` and `tag_constants.TPU`. + At serving time, these tags are used to select metagraph to load. + + Before running the graph on TPU, TPU system needs to be initialized. If + TensorFlow Serving model-server is used, this is done automatically. If + not, please call `session.run(tpu.initialize_system())`. + + `tpu.outside_compilation` can be used to wrap TPU incompatible ops in + `model_fn`. + + Example: + ---------------- + + ``` + def model_fn(features, labels, mode, config, params): + ... + logits = ... + export_outputs = { + 'logits': export_output_lib.PredictOutput( + {'logits': logits}) + } + + def host_call(logits): + class_ids = math_ops.argmax(logits) + classes = string_ops.as_string(class_ids) + export_outputs['classes'] = + export_output_lib.ClassificationOutput(classes=classes) + + tpu.outside_compilation(host_call, logits) + + ... + ``` + """ def __init__(self, @@ -1775,13 +1915,15 @@ class TPUEstimator(estimator_lib.Estimator): predict_batch_size=None, batch_axis=None, eval_on_tpu=True, + export_to_tpu=True, warm_start_from=None): """Constructs an `TPUEstimator` instance. Args: model_fn: Model function as required by `Estimator`. For training, the returned `EstimatorSpec` cannot have hooks as it is not supported in - `TPUEstimator`. + `TPUEstimator`. Instead, the user can pass the training hooks as + an argument to `TPUEstimator.train()`. model_dir: Directory to save model parameters, graph and etc. This can also be used to load checkpoints from the directory into a estimator to continue training a previously saved model. If `None`, the model_dir in @@ -1817,6 +1959,8 @@ class TPUEstimator(estimator_lib.Estimator): False or `PER_HOST_V2`, batch_axis is ignored. eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. + export_to_tpu: If True, `export_savedmodel()` exports a metagraph for + serving on TPU besides the one on CPU. warm_start_from: Optional string filepath to a checkpoint or SavedModel to warm-start from, or a `tf.estimator.WarmStartSettings` object to fully configure warm-starting. If the string @@ -1888,8 +2032,120 @@ class TPUEstimator(estimator_lib.Estimator): use_tpu, eval_on_tpu) + self._export_to_tpu = export_to_tpu + self._is_input_fn_invoked = None + def _add_meta_graph_for_mode(self, + builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=True, + mode=model_fn_lib.ModeKeys.PREDICT, + export_tags=None): + if mode != model_fn_lib.ModeKeys.PREDICT: + raise NotImplementedError( + 'TPUEstimator only handles mode PREDICT for export_savedmodel(); ' + 'got {}.'.format(mode)) + + super(TPUEstimator, self)._add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables, + mode=mode) + + if self._export_to_tpu: + input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: + input_receiver_fn_map[mode]} + export_tags = [tag_constants.SERVING, tag_constants.TPU] + mode = _REWRITE_FOR_INFERENCE_MODE + (super(TPUEstimator, self). + _add_meta_graph_for_mode(builder, + input_receiver_fn_map, + checkpoint_path, + strip_default_attrs, + save_variables=False, + mode=mode, + export_tags=export_tags)) + + def _call_model_fn(self, features, labels, mode, config): + if mode == _REWRITE_FOR_INFERENCE_MODE: + return self._call_model_fn_for_inference(features, labels, mode, config) + else: + return super(TPUEstimator, self)._call_model_fn( + features, labels, mode, config) + + def _call_model_fn_for_inference(self, features, labels, mode, config): + """Wraps `_call_model_fn` for `export_savedmodel`.""" + if mode != _REWRITE_FOR_INFERENCE_MODE: + raise ValueError('mode must be {}; ' + 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) + + capture = _CapturedObject() + + def computation(): + """Compute tpu tensors used in export_outputs. + + Passed to rewrite_for_inference so that model_fn will be called under + the rewriting contexts. Only tpu tensors are returned, but export_outputs + and scaffold are captured. + + Returns: + A list of Tensors used in export_outputs and not marked for + outside_compilation. + """ + # We should only call model fn once and it should be inside `computation` + # so that building the graph will happen under `rewrite_for_inference`. + mode = model_fn_lib.ModeKeys.PREDICT + estimator_spec = self._call_model_fn(features, labels, mode, config) + + # We pick the TPU tensors out from `export_output` and later return them + # from `computation` for rewriting. + tensors_dict = collections.OrderedDict( + (k, _export_output_to_tensors(v)) + for k, v in six.iteritems(estimator_spec.export_outputs) + ) + tensors = nest.flatten(tensors_dict) + tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] + + # We cannot return anything other than `tpu_tensors` here so we capture + # the rest for later use. + capture.capture((estimator_spec, tensors_dict, tensors)) + return tpu_tensors + + tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) + estimator_spec, tensors_dict, tensors = capture.get() + + # Reconstruct `tensors`, but with `tpu_tensors` replaced with + # `tpu_tensors_on_cpu`. + new_tensors = [] + for t in tensors: + if _is_tpu_tensor(t): + new_tensors.append(tpu_tensors_on_cpu.pop(0)) + elif t is None: + new_tensors.append(None) + else: + # Only fetching `tpu_tensors_on_cpu` does not trigger + # TPU computation and blocks, so we add the control dependency here. + control_inputs = (tpu_tensors_on_cpu + if isinstance(tpu_tensors_on_cpu, (list, tuple)) + else (tpu_tensors_on_cpu,)) + with ops.control_dependencies(control_inputs): + new_tensors.append(array_ops.identity(t)) + + # Reconstruct `tensors_dict`. + new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) + # Reconstruct `export_outputs`. + export_outputs = estimator_spec.export_outputs + new_export_outputs = collections.OrderedDict( + (k, _clone_export_output_with_tensors(export_outputs[k], v)) + for k, v in six.iteritems(new_tensors_dict) + ) + + return estimator_spec._replace(export_outputs=new_export_outputs) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -2072,11 +2328,11 @@ class TPUEstimator(estimator_lib.Estimator): if shutdown_mode: if shutdown_mode == 'shutdown_worker': finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=1000), + session_support.ShutdownLameWorkers(timeout_ms=60*1000), ] elif shutdown_mode == 'shutdown_computation': finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=1000), + session_support.RestartComputation(timeout_ms=60*1000), ] else: raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % @@ -2265,6 +2521,76 @@ class TPUEstimator(estimator_lib.Estimator): return _model_fn +def _is_tpu_tensor(tensor): + if not isinstance(tensor, ops.Tensor): + return False + try: + tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access + except ValueError: + return True + else: + return False + + +def _export_output_to_tensors(export_output): + """Get a list of `Tensors` used in `export_output`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + Returns: + a list of tensors used in export_output. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + return [export_output.scores, export_output.classes] + elif isinstance(export_output, export_output_lib.RegressionOutput): + return [export_output.value] + elif isinstance(export_output, export_output_lib.PredictOutput): + return export_output.outputs.values() + else: + raise ValueError( + '`export_output` must be have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + +def _clone_export_output_with_tensors(export_output, tensors): + """Clones `export_output` but with new `tensors`. + + Args: + export_output: an `ExportOutput` object such as `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + tensors: a list of `Tensors` used to construct a new `export_output`. + + Returns: + A dict similar to `export_output` but with `tensors`. + + Raises: + ValueError: if `export_output` is not one of `ClassificationOutput`, + `RegressionOutput`, or `PredictOutput`. + """ + if isinstance(export_output, export_output_lib.ClassificationOutput): + if len(tensors) != 2: + raise ValueError('tensors must be of length 2; ' + 'got {}.'.format(len(tensors))) + return export_output_lib.ClassificationOutput(*tensors) + elif isinstance(export_output, export_output_lib.RegressionOutput): + if len(tensors) != 1: + raise ValueError('tensors must be of length 1; ' + 'got {}'.format(len(tensors))) + return export_output_lib.RegressionOutput(*tensors) + elif isinstance(export_output, export_output_lib.PredictOutput): + return export_output_lib.PredictOutput( + dict(zip(export_output.outputs.keys(), tensors))) + else: + raise ValueError( + '`export_output` must be have type `ClassificationOutput`, ' + '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) + + def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" iterations_per_loop_var = _create_or_get_iterations_per_loop() @@ -2412,7 +2738,7 @@ class _CapturedObject(object): def capture(self, o): if self._captured: raise RuntimeError( - 'InternalError: Object can be captured only. Please file bug .') + 'InternalError: Object can capture only once. Please file bug.') self._captured = True self._object = o @@ -2421,7 +2747,7 @@ class _CapturedObject(object): if not self._captured: raise RuntimeError( 'InternalError: Object is not captured properly before `get`. ' - 'Please file bug .') + 'Please file bug.') return self._object @@ -2522,7 +2848,8 @@ class _Inputs(object): """ iterator = self._dataset.make_initializable_iterator() # pylint: disable=protected-access - hook = estimator_lib._DatasetInitializerHook(iterator) + hook = estimator_util._DatasetInitializerHook(iterator) + # pylint: enable=protected-access self._iterator = iterator return hook @@ -2668,6 +2995,7 @@ class _StopSignals(object): @staticmethod def should_stop(scalar_stopping_signal): + """Detects whether scalar_stopping_signal indicates stopping.""" if isinstance(scalar_stopping_signal, ops.Tensor): # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF # way to express the bool check whether scalar_stopping_signal is True. @@ -2831,4 +3159,3 @@ def _add_item_to_params(params, key, value): else: # Now params is Python dict. params[key] = value - diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index c3882b8a27bc835f906c47dc5219f280c53800b8..6bdaa528f9f946ae4b9813d554409da2406b1f8d 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.tpu.python.tpu import training_loop from tensorflow.python.framework import dtypes from tensorflow.python.layers import convolutional from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops @@ -37,7 +38,8 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - context = tpu.TPUReplicateContext(b"context", 1) + pivot = control_flow_ops.no_op() + context = tpu.TPUReplicateContext(b"context", 1, pivot=pivot) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py index 409aba817c1ec37003eb98f000f6cf8918234c5d..a2444934bc21d58ed57d15494b3548a31ce3a2df 100644 --- a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py +++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -45,14 +46,14 @@ class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset): self._input_dataset = input_dataset self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") - # pylint: disable=protected-access if padded_shapes is None: self._padded_shapes = nest.map_structure( - dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes) + convert.partial_shape_to_tensor, input_dataset.output_shapes) else: self._padded_shapes = nest.map_structure_up_to( - input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor, + input_dataset.output_shapes, convert.partial_shape_to_tensor, padded_shapes) + # pylint: disable=protected-access padding_values = ( padding_values if padding_values is not None else dataset_ops._default_padding(input_dataset)) diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index 9720fd6e8657de18cf8d7565f834568ae52fdbda..1b45584dcb84fe62de6cc14017e7cae575f99b2f 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -58,7 +58,7 @@ cc_library( "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], alwayslink = 1, ) @@ -69,7 +69,7 @@ cc_library( hdrs = ["grpc_verbs_service_impl.h"], deps = [ ":verbs_service_proto_cc", - "@grpc//:grpc++_unsecure", + "@grpc//:grpc++", ], ) diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc index 742f946c9536973eb8a6a11afda1b32ae4a7726b..af29abd91feda22824e57c19c13a3f48fb1d61b7 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc @@ -15,9 +15,9 @@ limitations under the License. #ifdef TENSORFLOW_USE_VERBS -#include "grpc++/alarm.h" -#include "grpc++/grpc++.h" -#include "grpc++/server_builder.h" +#include "grpcpp/alarm.h" +#include "grpcpp/grpcpp.h" +#include "grpcpp/server_builder.h" #include "tensorflow/contrib/verbs/grpc_verbs_service.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index 991f9a9d8bdf883b1b68bfa1fb6af7bf51b7e66a..4da7b59c69c88a4d04be37543aae7f03decd2c52 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -15,14 +15,14 @@ limitations under the License. #include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h" -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/channel_interface.h" -#include "grpc++/impl/codegen/client_unary_call.h" -#include "grpc++/impl/codegen/method_handler_impl.h" -#include "grpc++/impl/codegen/rpc_service_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/channel_interface.h" +#include "grpcpp/impl/codegen/client_unary_call.h" +#include "grpcpp/impl/codegen/method_handler_impl.h" +#include "grpcpp/impl/codegen/rpc_service_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/sync_stream.h" namespace tensorflow { diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 1f0f10517e98a32ae882c027330091928f1a6ee2..abe5e08b07cd71b7ca28321e6eb2cf0eec5d1b0f 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -16,14 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ #define TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_ -#include "grpc++/impl/codegen/async_stream.h" -#include "grpc++/impl/codegen/async_unary_call.h" -#include "grpc++/impl/codegen/proto_utils.h" -#include "grpc++/impl/codegen/rpc_method.h" -#include "grpc++/impl/codegen/service_type.h" -#include "grpc++/impl/codegen/status.h" -#include "grpc++/impl/codegen/stub_options.h" -#include "grpc++/impl/codegen/sync_stream.h" +#include "grpcpp/impl/codegen/async_stream.h" +#include "grpcpp/impl/codegen/async_unary_call.h" +#include "grpcpp/impl/codegen/proto_utils.h" +#include "grpcpp/impl/codegen/rpc_method.h" +#include "grpcpp/impl/codegen/service_type.h" +#include "grpcpp/impl/codegen/status.h" +#include "grpcpp/impl/codegen/stub_options.h" +#include "grpcpp/impl/codegen/sync_stream.h" #include "tensorflow/contrib/verbs/verbs_service.pb.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a576e360973381d6df3ed399b207a987006d58c7..b6b48a077cdafe12aeb1e4e0988493692c82eace 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,24 +72,23 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "cc_header_only_library", "full_path", "if_android", - "if_not_android_mips_and_mips64", "if_ios", "if_linux_x86_64", "if_mobile", "if_not_mobile", - "if_windows", "if_not_windows", - "tf_copts", + "if_windows", "tf_cc_test", "tf_cc_tests", + "tf_copts", "tf_cuda_library", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", - "cc_header_only_library", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -101,49 +100,48 @@ load("//tensorflow:tensorflow.bzl", "tf_cuda_only_cc_test") # For platform specific build config load( "//tensorflow/core:platform/default/build_config.bzl", - "tf_platform_hdrs", - "tf_platform_srcs", - "tf_proto_library", - "tf_proto_library_cc", "tf_additional_all_protos", + "tf_additional_cloud_kernel_deps", + "tf_additional_cloud_op_deps", "tf_additional_core_deps", + "tf_additional_cupti_wrapper_deps", + "tf_additional_device_tracer_cuda_deps", + "tf_additional_device_tracer_deps", + "tf_additional_device_tracer_srcs", + "tf_additional_gdr_lib_defines", + "tf_additional_human_readable_json_deps", "tf_additional_lib_defines", "tf_additional_lib_deps", "tf_additional_lib_hdrs", "tf_additional_lib_srcs", - "tf_additional_framework_hdrs", - "tf_additional_framework_srcs", - "tf_additional_minimal_lib_srcs", - "tf_additional_proto_hdrs", - "tf_additional_proto_srcs", - "tf_additional_cupti_wrapper_deps", "tf_additional_libdevice_data", "tf_additional_libdevice_deps", "tf_additional_libdevice_srcs", + "tf_additional_minimal_lib_srcs", + "tf_additional_mpi_lib_defines", + "tf_additional_proto_hdrs", + "tf_additional_proto_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", - "tf_kernel_tests_linkstatic", - "tf_additional_cloud_op_deps", - "tf_additional_cloud_kernel_deps", - "tf_lib_proto_parsing_deps", "tf_additional_verbs_lib_defines", - "tf_additional_mpi_lib_defines", - "tf_additional_gdr_lib_defines", - "tf_additional_device_tracer_srcs", - "tf_additional_device_tracer_deps", - "tf_additional_device_tracer_cuda_deps", - "tf_pyclif_proto_library", "tf_jspb_proto_library", + "tf_kernel_tests_linkstatic", + "tf_lib_proto_parsing_deps", "tf_nano_proto_library", + "tf_platform_hdrs", + "tf_platform_srcs", + "tf_proto_library", + "tf_proto_library_cc", "tf_protos_all", "tf_protos_all_impl", "tf_protos_grappler", "tf_protos_grappler_impl", + "tf_pyclif_proto_library", ) load( "//tensorflow/core:platform/default/build_config_root.bzl", - "tf_cuda_tests_tags", "if_static", + "tf_cuda_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") @@ -295,43 +293,18 @@ cc_library( ], ) -PLATFORM_BASE_HDRS = [ - "platform/env_time.h", - "platform/logging.h", - "platform/macros.h", - "platform/types.h", - "platform/byte_order.h", -] - -PLATFORM_OTHER_HDRS = [ - "platform/abi.h", - "platform/stacktrace.h", - "platform/stacktrace_handler.h", - "platform/context.h", - "platform/cpu_info.h", - "platform/cpu_feature_guard.h", - "platform/dynamic_annotations.h", - "platform/error.h", - "platform/env.h", - "platform/file_system.h", - "platform/file_system_helper.h", - "platform/fingerprint.h", - "platform/init_main.h", - "platform/mem.h", - "platform/mutex.h", - "platform/net.h", - "platform/notification.h", - "platform/null_file_system.h", - "platform/prefetch.h", - "platform/profile_utils/clock_cycle_profiler.h", - "platform/profile_utils/cpu_utils.h", - "platform/protobuf.h", - "platform/strong_hash.h", - "platform/subprocess.h", - "platform/thread_annotations.h", -] +filegroup( + name = "platform_base_hdrs", + srcs = [ + "platform/byte_order.h", + "platform/env_time.h", + "platform/logging.h", + "platform/macros.h", + "platform/types.h", + ], + visibility = ["//visibility:private"], +) -# Smaller platform libraries that don't depend on "lib" or "lib_internal". cc_library( name = "platform_base", srcs = tf_platform_hdrs([ @@ -343,16 +316,275 @@ cc_library( ]) + [ "platform/env_time.cc", ], - hdrs = PLATFORM_BASE_HDRS, + hdrs = [":platform_base_hdrs"], copts = tf_copts(), - # TODO(ahentz): remove use of this library so we can move it into 'platform' tags = ["avoid_dep"], + visibility = ["//tensorflow/core:__subpackages__"], deps = [ ":lib_platform", "//tensorflow/core/platform/default/build_config:base", ], ) +filegroup( + name = "platform_port_hdrs", + srcs = [ + "platform/cpu_info.h", + "platform/dynamic_annotations.h", + "platform/init_main.h", + "platform/mem.h", + "platform/mutex.h", + "platform/thread_annotations.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_port_internal_hdrs", + srcs = [ + "platform/demangle.h", + "platform/host_info.h", + "platform/snappy.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_port", + srcs = tf_platform_hdrs([ + "cpu_info.h", + "dynamic_annotations.h", + "thread_annotations.h", + "mutex.h", + ]) + tf_platform_srcs([ + "port.cc", + ]) + [ + "platform/cpu_info.cc", + ], + hdrs = [ + ":platform_port_hdrs", + ":platform_port_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib_platform", + ":platform_base", + "//tensorflow/core/platform/default/build_config:port", + "@snappy", + ], +) + +filegroup( + name = "platform_protobuf_hdrs", + srcs = [ + "platform/protobuf.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_protobuf_internal_hdrs", + srcs = [ + "platform/protobuf_internal.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_protobuf", + srcs = tf_platform_hdrs([ + "protobuf.h", + ]) + tf_platform_srcs([ + "protobuf.cc", + ]) + [ + "platform/protobuf_util.cc", + "lib/core/status.h", + ], + hdrs = [ + ":platform_protobuf_hdrs", + ":platform_protobuf_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib_platform", + ":platform_base", + ":platform_port", + "//tensorflow/core/platform/default/build_config:protobuf", + "@protobuf_archive//:protobuf", + ], +) + +cc_library( + name = "human_readable_json", + srcs = tf_platform_srcs(["human_readable_json.cc"]), + hdrs = ["platform/human_readable_json.h"], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":lib", + ":lib_internal", + ] + tf_additional_human_readable_json_deps(), +) + +filegroup( + name = "platform_env_hdrs", + srcs = [ + "platform/env.h", + "platform/file_statistics.h", + "platform/file_system.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_env_internal_hdrs", + srcs = [ + "platform/load_library.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_env", + srcs = tf_platform_srcs([ + "env.cc", + "load_library.cc", + ]) + tf_platform_hdrs([ + "wide_char.h", + ]) + [ + "platform/env.cc", + "platform/file_system.cc", + ], + hdrs = [ + ":platform_env_hdrs", + ":platform_env_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":error_codes_proto_cc", + ":lib", + ":lib_internal", + ":lib_platform", + ":platform_base", + ":platform_port", + ":platform_protobuf", + "//tensorflow/core/platform/default/build_config:env", + ], +) + +filegroup( + name = "platform_file_system_hdrs", + srcs = [ + "platform/file_system_helper.h", + "platform/null_file_system.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_file_system", + srcs = tf_platform_srcs([ + ]) + tf_platform_hdrs([ + "windows_file_system.h", + ]) + [ + "platform/file_system_helper.cc", + ], + hdrs = [ + ":platform_file_system_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib", + ":lib_platform", + ":platform_env", + ], +) + +filegroup( + name = "platform_other_hdrs", + srcs = [ + "platform/abi.h", + "platform/context.h", + "platform/cpu_feature_guard.h", + "platform/error.h", + "platform/fingerprint.h", + "platform/net.h", + "platform/notification.h", + "platform/prefetch.h", + "platform/profile_utils/android_armv7a_cpu_utils_helper.h", + "platform/profile_utils/clock_cycle_profiler.h", + "platform/profile_utils/cpu_utils.h", + "platform/profile_utils/i_cpu_utils_helper.h", + "platform/stacktrace.h", + "platform/stacktrace_handler.h", + "platform/strong_hash.h", + "platform/subprocess.h", + ], + visibility = ["//visibility:private"], +) + +# Headers that are not exported as part of ":lib". +filegroup( + name = "platform_other_internal_hdrs", + srcs = [ + "platform/denormal.h", + "platform/setround.h", + "platform/tracing.h", + ], + visibility = ["//visibility:private"], +) + +cc_library( + name = "platform_other", + srcs = tf_platform_srcs([ + "subprocess.cc", + "net.cc", + "tracing.cc", + ]) + tf_platform_hdrs([ + "tracing.h", + "error.h", + "context.h", + "fingerprint.h", + "notification.h", + "stacktrace.h", + "strong_hash.h", + "subprocess.h", + "tracing_impl.h", + ]) + [ + "platform/cpu_feature_guard.cc", + "platform/setround.cc", + "platform/tracing.cc", + "platform/denormal.cc", + "platform/profile_utils/android_armv7a_cpu_utils_helper.cc", + "platform/profile_utils/clock_cycle_profiler.cc", + "platform/profile_utils/cpu_utils.cc", + ], + hdrs = [ + ":platform_other_hdrs", + ":platform_other_internal_hdrs", + ], + copts = tf_copts(), + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + ":lib", + ":lib_platform", + ":platform_base", + ":platform_env", + ":platform_port", + ":platform_protobuf", + "//tensorflow/core/platform/default/build_config:other", + "//tensorflow/core/platform/default/build_config:platformlib", + "//tensorflow/core/platform/default/build_config:port", + ], +) + # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -386,8 +618,7 @@ cc_library( # tf_cc_test and tf_cc_binary will include the necessary symbols. cc_library( name = "lib", - hdrs = PLATFORM_BASE_HDRS + - PLATFORM_OTHER_HDRS + [ + hdrs = [ "lib/bfloat16/bfloat16.h", "lib/core/arena.h", "lib/core/bitmap.h", @@ -434,6 +665,12 @@ cc_library( "lib/strings/str_util.h", "lib/strings/strcat.h", "lib/strings/stringprintf.h", + ":platform_base_hdrs", + ":platform_env_hdrs", + ":platform_file_system_hdrs", + ":platform_other_hdrs", + ":platform_port_hdrs", + ":platform_protobuf_hdrs", ], visibility = ["//visibility:public"], deps = [ @@ -607,6 +844,7 @@ tf_cuda_library( "util/sparse/group_iterator.h", "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", + "util/stat_summarizer_options.h", "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", @@ -631,6 +869,18 @@ tf_cuda_library( deps = [":framework_internal"], ) +cc_library( + name = "stats_calculator_portable", + srcs = [ + "util/stat_summarizer_options.h", + "util/stats_calculator.cc", + ], + hdrs = [ + "util/stats_calculator.h", + ], + copts = tf_copts(), +) + cc_library( name = "overflow", hdrs = ["util/overflow.h"], @@ -640,6 +890,12 @@ cc_library( ], ) +cc_library( + name = "exec_on_stall", + hdrs = ["util/exec_on_stall.h"], + deps = [":framework_lite"], +) + cc_library( name = "ptr_util", hdrs = ["util/ptr_util.h"], @@ -1111,6 +1367,7 @@ cc_library( ":shape_inference_testutil", ":tensor_testutil", ":test", + ":testlib_ops", "//tensorflow/cc:scope", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:ops_testutil", @@ -1118,6 +1375,18 @@ cc_library( ], ) +cc_library( + name = "testlib_ops", + testonly = 1, + srcs = ["common_runtime/testlib_ops.cc"], + linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + # This is a link-only library to provide a DirectSession # implementation of the Session interface. tf_cuda_library( @@ -1179,6 +1448,7 @@ filegroup( "lib/png/**/*", "lib/gif/**/*", "util/events_writer.*", + "util/stats_calculator.*", "util/reporter.*", "platform/**/cuda_libdevice_path.*", "platform/default/test_benchmark.*", @@ -1262,6 +1532,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -1302,6 +1573,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":protos_all_cc_impl", + ":stats_calculator_portable", "//third_party/eigen3", "@double_conversion//:double-conversion", "@nsync//:nsync_cpp", @@ -1767,9 +2039,8 @@ cc_library( "platform/**/cuda_libdevice_path.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", - "platform/variant_coding.cc", - "platform/**/variant_cord_coding.cc", ], ) + tf_additional_lib_srcs( exclude = [ @@ -1781,9 +2052,8 @@ cc_library( "platform/**/env_time.cc", "platform/**/device_tracer.cc", "platform/**/logging.cc", + "platform/**/human_readable_json.cc", "platform/abi.cc", - "platform/variant_coding.cc", - "platform/**/variant_cord_coding.cc", ] + # Protobuf deps already included through the ":lib_proto_parsing" # dependency. @@ -2033,7 +2303,6 @@ cc_library( ) FRAMEWORK_INTERNAL_PRIVATE_HEADERS = [ - "platform/variant_coding.h", "graph/edgeset.h", "graph/graph.h", "graph/graph_def_builder.h", @@ -2074,14 +2343,13 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", "framework/variant.h", - "platform/variant_coding.h", "util/command_line_flags.h", "util/env_var.h", "util/equal_graph_def.h", "util/presized_cuckoo_map.h", "util/tensor_slice_set.h", "util/tensor_slice_util.h", -] + tf_additional_framework_hdrs() +] tf_cuda_library( name = "framework_internal", @@ -2123,9 +2391,7 @@ cc_header_only_library( tf_cuda_library( name = "framework_internal_impl", - srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + [ - "platform/variant_coding.cc", - ] + glob( + srcs = FRAMEWORK_INTERNAL_PRIVATE_HEADERS + glob( [ "example/**/*.cc", "framework/**/*.cc", @@ -2159,7 +2425,7 @@ tf_cuda_library( "util/memmapped_file_system.cc", "util/memmapped_file_system_writer.cc", ], - }) + tf_additional_framework_srcs(), + }), hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), linkopts = select({ @@ -2368,6 +2634,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/dma_helper.h", "common_runtime/eigen_thread_pool.h", "common_runtime/executor.h", + "common_runtime/executor_factory.h", "common_runtime/graph_optimizer.h", "common_runtime/local_device.h", "common_runtime/lower_if_op.h", @@ -2417,6 +2684,7 @@ tf_cuda_library( "common_runtime/device_resolver_local.cc", "common_runtime/device_set.cc", "common_runtime/executor.cc", + "common_runtime/executor_factory.cc", "common_runtime/function.cc", "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", @@ -2524,6 +2792,7 @@ cc_library( ], visibility = [ "//tensorflow/compiler:__subpackages__", + "//tensorflow/core/kernels:__subpackages__", "//tensorflow/core/profiler:__subpackages__", ], deps = [":lib_internal"], @@ -2997,6 +3266,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "exec_on_stall_test", + size = "small", + srcs = ["util/exec_on_stall_test.cc"], + deps = [ + ":exec_on_stall", + ":framework_lite", + ":test", + ":test_main", + ], +) + tf_cc_test( name = "lib_jpeg_jpeg_mem_unittest", srcs = ["lib/jpeg/jpeg_mem_unittest.cc"], @@ -3760,6 +4041,31 @@ tf_cc_test( ], ) +tf_cc_test( + name = "common_runtime_executor_test", + size = "small", + srcs = ["common_runtime/executor_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core", + ":core_cpu", + ":core_cpu_internal", + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":protos_all_cc", + ":test", + ":test_main", + ":testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + tf_cc_test( name = "common_runtime_function_test", size = "small", diff --git a/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..d8c2ed40a324d4854d83c471e8eef50e50277b93 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AnonymousIterator.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "AnonymousIterator" + out_arg { + name: "handle" + description: <